Unverified commit 174dd189, authored by Zihao Ye and committed by GitHub

[hotfix] A bunch of fixes for SetTransformer (#2658)

* upd

* lint

* upd

* fix

* fix

* upd

* upd

* fix

* warning

* upd

* upd
parent 48a1794f
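
The core change addresses the NaN issue reported in https://github.com/dmlc/dgl/issues/2657: when every memory position of a query row is padding, the entire attention-score row is -inf and softmax over it produces NaN. Below is a minimal sketch (not part of the commit; the tensors are made up for illustration) of the symptom and of the remedy the patch applies after the softmax:

import torch as th

e = th.randn(1, 1, 2, 3)                          # (batch, heads, len_x, len_mem)
mask = th.tensor([[[[1, 1, 1],
                    [0, 0, 0]]]], dtype=th.bool)  # second query row is pure padding
e = e.masked_fill(mask == 0, -float('inf'))
alpha = th.softmax(e, dim=-1)
print(alpha[0, 0, 1])                             # tensor([nan, nan, nan])
alpha = alpha.masked_fill(mask == 0, 0.)          # the hotfix: zero out padded rows
print(alpha[0, 0, 1])                             # tensor([0., 0., 0.])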
@@ -5,6 +5,7 @@ import torch.nn as nn
import numpy as np
from ...backend import pytorch as F
from ...base import dgl_warning
from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
    softmax_nodes, topk_nodes
@@ -610,6 +611,35 @@ class Set2Set(nn.Module):
return summary.format(**self.__dict__)
def _gen_mask(lengths_x, lengths_y, max_len_x, max_len_y):
    """Generate a binary mask tensor for the given x and y input pairs.

    Parameters
    ----------
    lengths_x : Tensor
        An int tensor holding the number of elements in each segment of x.
    lengths_y : Tensor
        An int tensor holding the number of elements in each segment of y.
    max_len_x : int
        The maximum element in lengths_x.
    max_len_y : int
        The maximum element in lengths_y.

    Returns
    -------
    Tensor
        The mask tensor with shape (batch_size, 1, max_len_x, max_len_y).
    """
    device = lengths_x.device
    # x_mask: (batch_size, max_len_x)
    x_mask = th.arange(max_len_x, device=device).unsqueeze(0) < lengths_x.unsqueeze(1)
    # y_mask: (batch_size, max_len_y)
    y_mask = th.arange(max_len_y, device=device).unsqueeze(0) < lengths_y.unsqueeze(1)
    # mask: (batch_size, 1, max_len_x, max_len_y)
    mask = (x_mask.unsqueeze(-1) & y_mask.unsqueeze(-2)).unsqueeze(1)
    return mask
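# Illustrative usage of the new _gen_mask helper (not part of this patch;
# the lengths below are made-up examples): segments of length 2 and 3 on the
# x side paired with segments of length 1 and 3 on the y side.
import torch as th
lengths_x = th.tensor([2, 3])
lengths_y = th.tensor([1, 3])
mask = _gen_mask(lengths_x, lengths_y, 3, 3)
print(mask.shape)    # torch.Size([2, 1, 3, 3])
print(mask[0, 0])    # True only on the top-left 2 x 1 block
print(mask[1, 0])    # all True: the second (x, y) pair has no padding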
class MultiHeadAttention(nn.Module):
r"""Multi-Head Attention block, used in Transformer, Set Transformer and so on.
@@ -678,6 +708,9 @@ class MultiHeadAttention(nn.Module):
batch_size = len(lengths_x)
max_len_x = max(lengths_x)
max_len_mem = max(lengths_mem)
device = x.device
lengths_x = th.tensor(lengths_x, dtype=th.int64, device=device)
lengths_mem = th.tensor(lengths_mem, dtype=th.int64, device=device)
queries = self.proj_q(x).view(-1, self.num_heads, self.d_head)
keys = self.proj_k(mem).view(-1, self.num_heads, self.d_head)
@@ -694,14 +727,15 @@
e = e / np.sqrt(self.d_head)
# generate mask
mask = th.zeros(batch_size, max_len_x, max_len_mem).to(e.device)
for i in range(batch_size):
    mask[i, :lengths_x[i], :lengths_mem[i]].fill_(1)
mask = mask.unsqueeze(1)
e.masked_fill_(mask == 0, -float('inf'))
mask = _gen_mask(lengths_x, lengths_mem, max_len_x, max_len_mem)
e = e.masked_fill(mask == 0, -float('inf'))
# apply softmax
alpha = th.softmax(e, dim=-1)
# the following line addresses the NaN issue, see
# https://github.com/dmlc/dgl/issues/2657
alpha = alpha.masked_fill(mask == 0, 0.)
# sum of value weighted by alpha
out = th.einsum('bhxy,byhd->bxhd', alpha, values)
# project to output
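# Illustrative check (not part of this patch): with the _gen_mask helper
# defined above and small made-up lengths, the vectorized mask reproduces the
# boolean pattern built by the per-graph loop that this hunk removes.
import torch as th
lengths_x, lengths_mem = th.tensor([2, 3]), th.tensor([1, 3])
max_len_x, max_len_mem = 3, 3
old = th.zeros(2, max_len_x, max_len_mem)
for i in range(2):
    old[i, :lengths_x[i], :lengths_mem[i]].fill_(1)
old = old.unsqueeze(1)                            # (batch, 1, len_x, len_mem)
new = _gen_mask(lengths_x, lengths_mem, max_len_x, max_len_mem)
assert th.equal(old.bool(), new)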
@@ -765,6 +799,8 @@ class InducedSetAttentionBlock(nn.Module):
Parameters
----------
m : int
    The number of induced vectors.
d_model : int
    The feature size (input and output) in Multi-Head Attention layer.
num_heads : int
@@ -785,6 +821,9 @@
def __init__(self, m, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
    super(InducedSetAttentionBlock, self).__init__()
    self.m = m
    if m == 1:
        dgl_warning("if m is set to 1, the parameters corresponding to query and key "
                    "projections would not get updated during training.")
    self.d_model = d_model
    self.inducing_points = nn.Parameter(
        th.FloatTensor(m, d_model)
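# Minimal illustration (not part of this patch) of the effect the new
# dgl_warning describes; the same reasoning applies to the k == 1 warning in
# PMALayer below. A softmax taken over a single attention score is identically
# 1, so no gradient flows back through that score to the query/key projections.
import torch as th
q = th.randn(4, requires_grad=True)
k = th.randn(4, requires_grad=True)
e = (q * k).sum().reshape(1)              # a single attention score
alpha = th.softmax(e, dim=-1)             # softmax over one element is exactly 1
(alpha * th.randn(1)).sum().backward()
print(q.grad.abs().sum(), k.grad.abs().sum())   # both zero: q and k get no signal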
@@ -832,6 +871,8 @@ class PMALayer(nn.Module):
Parameters
----------
k : int
    The number of seed vectors.
d_model : int
    The feature size (input and output) in Multi-Head Attention layer.
num_heads : int
@@ -852,6 +893,9 @@
def __init__(self, k, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
    super(PMALayer, self).__init__()
    self.k = k
    if k == 1:
        dgl_warning("if k is set to 1, the parameters corresponding to query and key "
                    "projections would not get updated during training.")
    self.d_model = d_model
    self.seed_vectors = nn.Parameter(
        th.FloatTensor(k, d_model)