"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "fac75e166b4a4a7f84ac5d12c3b8f4ba01cda57b"
Unverified commit 174dd189 authored by Zihao Ye, committed by GitHub

[hotfix] A bunch of fix for SetTransformer (#2658)

* upd

* lint

* upd

* fix

* fix

* upd

* upd

* fix

* warning

* upd

* upd
parent 48a1794f
@@ -5,6 +5,7 @@ import torch.nn as nn
 import numpy as np
 from ...backend import pytorch as F
+from ...base import dgl_warning
 from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
     softmax_nodes, topk_nodes
@@ -610,6 +611,35 @@ class Set2Set(nn.Module):
         return summary.format(**self.__dict__)
+def _gen_mask(lengths_x, lengths_y, max_len_x, max_len_y):
+    """Generate a binary mask array for given x and y input pairs.
+
+    Parameters
+    ----------
+    lengths_x : Tensor
+        The int tensor indicating the segment information of x.
+    lengths_y : Tensor
+        The int tensor indicating the segment information of y.
+    max_len_x : int
+        The maximum element in lengths_x.
+    max_len_y : int
+        The maximum element in lengths_y.
+
+    Returns
+    -------
+    Tensor
+        The mask tensor with shape (batch_size, 1, max_len_x, max_len_y).
+    """
+    device = lengths_x.device
+    # x_mask: (batch_size, max_len_x)
+    x_mask = th.arange(max_len_x, device=device).unsqueeze(0) < lengths_x.unsqueeze(1)
+    # y_mask: (batch_size, max_len_y)
+    y_mask = th.arange(max_len_y, device=device).unsqueeze(0) < lengths_y.unsqueeze(1)
+    # mask: (batch_size, 1, max_len_x, max_len_y)
+    mask = (x_mask.unsqueeze(-1) & y_mask.unsqueeze(-2)).unsqueeze(1)
+    return mask
 class MultiHeadAttention(nn.Module):
     r"""Multi-Head Attention block, used in Transformer, Set Transformer and so on.
@@ -678,6 +708,9 @@ class MultiHeadAttention(nn.Module):
         batch_size = len(lengths_x)
         max_len_x = max(lengths_x)
         max_len_mem = max(lengths_mem)
+        device = x.device
+        lengths_x = th.tensor(lengths_x, dtype=th.int64, device=device)
+        lengths_mem = th.tensor(lengths_mem, dtype=th.int64, device=device)
         queries = self.proj_q(x).view(-1, self.num_heads, self.d_head)
         keys = self.proj_k(mem).view(-1, self.num_heads, self.d_head)
@@ -694,14 +727,15 @@ class MultiHeadAttention(nn.Module):
         e = e / np.sqrt(self.d_head)
         # generate mask
-        mask = th.zeros(batch_size, max_len_x, max_len_mem).to(e.device)
-        for i in range(batch_size):
-            mask[i, :lengths_x[i], :lengths_mem[i]].fill_(1)
-        mask = mask.unsqueeze(1)
-        e.masked_fill_(mask == 0, -float('inf'))
+        mask = _gen_mask(lengths_x, lengths_mem, max_len_x, max_len_mem)
+        e = e.masked_fill(mask == 0, -float('inf'))
         # apply softmax
         alpha = th.softmax(e, dim=-1)
+        # the following line addresses the NaN issue, see
+        # https://github.com/dmlc/dgl/issues/2657
+        alpha = alpha.masked_fill(mask == 0, 0.)
         # sum of value weighted by alpha
         out = th.einsum('bhxy,byhd->bxhd', alpha, values)
         # project to output
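The second masked_fill above is the core of this hotfix. When a query position is entirely padding, its whole score row is set to -inf and softmax over that row evaluates to 0/0, i.e. NaN; zeroing the attention weights at masked positions afterwards keeps those NaNs out of the output (see dmlc/dgl#2657). A minimal standalone reproduction of the effect, not DGL code:

import torch as th

# Scores for two query positions over two memory slots; the second query
# position is padding, so its entire row has been masked to -inf.
e = th.tensor([[0.5, 1.0],
               [-float('inf'), -float('inf')]])
alpha = th.softmax(e, dim=-1)
print(alpha)   # second row is [nan, nan]

# The fix applied in the hunk above: zero the weights at masked positions so
# fully padded query positions contribute nothing instead of propagating NaN.
mask = th.tensor([[True, True],
                  [False, False]])
alpha = alpha.masked_fill(mask == 0, 0.)
print(alpha)   # second row is now [0., 0.]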
@@ -765,6 +799,8 @@ class InducedSetAttentionBlock(nn.Module):
     Parameters
     ----------
+    m : int
+        The number of induced vectors.
     d_model : int
         The feature size (input and output) in Multi-Head Attention layer.
     num_heads : int
@@ -785,6 +821,9 @@ class InducedSetAttentionBlock(nn.Module):
     def __init__(self, m, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
         super(InducedSetAttentionBlock, self).__init__()
         self.m = m
+        if m == 1:
+            dgl_warning("if m is set to 1, the parameters corresponding to query and key "
+                        "projections would not get updated during training.")
         self.d_model = d_model
         self.inducing_points = nn.Parameter(
             th.FloatTensor(m, d_model)
@@ -832,6 +871,8 @@ class PMALayer(nn.Module):
     Parameters
     ----------
+    k : int
+        The number of seed vectors.
     d_model : int
         The feature size (input and output) in Multi-Head Attention layer.
     num_heads : int
@@ -852,6 +893,9 @@ class PMALayer(nn.Module):
     def __init__(self, k, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
         super(PMALayer, self).__init__()
         self.k = k
+        if k == 1:
+            dgl_warning("if k is set to 1, the parameters corresponding to query and key "
+                        "projections would not get updated during training.")
         self.d_model = d_model
         self.seed_vectors = nn.Parameter(
             th.FloatTensor(k, d_model)
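The m == 1 and k == 1 warnings added above both stem from the same degenerate case: when attention is taken over a single element, the softmax along that axis is identically 1, so the attention weights are constant and pass no gradient back to the query/key projections. A standalone illustration of that general fact (simplified single-head attention without a value projection, not the DGL implementation):

import torch as th
import torch.nn as nn

d = 4
proj_q, proj_k = nn.Linear(d, d), nn.Linear(d, d)
x = th.randn(1, d)      # a single query
mem = th.randn(1, d)    # a single key/value element

e = proj_q(x) @ proj_k(mem).t() / d ** 0.5
alpha = th.softmax(e, dim=-1)   # softmax over one element is exactly 1, whatever e is
out = alpha @ mem
out.sum().backward()

print(alpha.item())                            # 1.0
print(proj_q.weight.grad.abs().sum().item())   # 0.0: no gradient reaches proj_q
print(proj_k.weight.grad.abs().sum().item())   # 0.0: no gradient reaches proj_k

Presumably this is the situation the induced-set and pooling attention blocks hit when only one inducing point or seed vector is used, which is why their constructors now warn about it.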