Unverified Commit cd817a1a authored by ZhenyuLU_Heliodore and committed by GitHub

[NN] Add SpatialEncoder and SpatialEncoder3d (#4991)



* Add SpatialEncoder and SpatialEncoder3d

* Optimize the code execution efficiency

* Fixed certain problems according to Dongyu's suggestions.

* Fix a possible division-by-zero error in PathEncoder; change certain designs in SpatialEncoder

* Fix a typo

* polish the docstring

* fix doc

* lint
Co-authored-by: Ubuntu <ubuntu@ip-172-31-14-146.ap-northeast-1.compute.internal>
Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent ce378327
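For reference, here is a minimal usage sketch of the two new modules, mirroring the docstring examples in the diff below (the graph and the random coordinates are illustrative only):

>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder, SpatialEncoder3d

>>> g = dgl.graph((th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
...                th.tensor([1, 2, 3, 0, 3, 0, 0, 1])))
>>> coord = th.rand(4, 3)  # one 3D coordinate per node
>>> bias_2d = SpatialEncoder(max_dist=2, num_heads=8)(g)
>>> bias_3d = SpatialEncoder3d(num_kernels=4, num_heads=8)(g, coord)
>>> bias_2d.shape, bias_3d.shape  # per-head attention biases, padded to (batch, N, N, num_heads)
(torch.Size([1, 4, 4, 8]), torch.Size([1, 4, 4, 8]))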
...
@@ -120,6 +120,8 @@ Utility Modules
    ~dgl.nn.pytorch.graph_transformer.BiasedMultiheadAttention
    ~dgl.nn.pytorch.graph_transformer.GraphormerLayer
    ~dgl.nn.pytorch.graph_transformer.PathEncoder
    ~dgl.nn.pytorch.graph_transformer.SpatialEncoder
    ~dgl.nn.pytorch.graph_transformer.SpatialEncoder3d

Network Embedding Modules
----------------------------------------
...
"""Torch modules for graph transformers.""" """Torch modules for graph transformers."""
import torch as th import math
import torch.nn as nn
import torch.nn.functional as F import torch as th
from ...convert import to_homogeneous import torch.nn as nn
from ...batch import unbatch import torch.nn.functional as F
from ...transforms import shortest_dist
from ...batch import unbatch
__all__ = [ from ...convert import to_homogeneous
"DegreeEncoder", from ...transforms import shortest_dist
"PathEncoder",
"BiasedMultiheadAttention", __all__ = [
"GraphormerLayer" "DegreeEncoder",
] "BiasedMultiheadAttention",
"PathEncoder",
class DegreeEncoder(nn.Module): "GraphormerLayer",
r"""Degree Encoder, as introduced in "SpatialEncoder",
`Do Transformers Really Perform Bad for Graph Representation? "SpatialEncoder3d",
<https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__ ]
This module is a learnable degree embedding module.
Parameters class DegreeEncoder(nn.Module):
---------- r"""Degree Encoder, as introduced in
max_degree : int `Do Transformers Really Perform Bad for Graph Representation?
Upper bound of degrees to be encoded. <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__
Each degree will be clamped into the range [0, ``max_degree``]. This module is a learnable degree embedding module.
embedding_dim : int
Output dimension of embedding vectors. Parameters
direction : str, optional ----------
Degrees of which direction to be encoded, max_degree : int
selected from ``in``, ``out`` and ``both``. Upper bound of degrees to be encoded.
``both`` encodes degrees from both directions Each degree will be clamped into the range [0, ``max_degree``].
and output the addition of them. embedding_dim : int
Default : ``both``. Output dimension of embedding vectors.
direction : str, optional
Example Degrees of which direction to be encoded,
------- selected from ``in``, ``out`` and ``both``.
>>> import dgl ``both`` encodes degrees from both directions
>>> from dgl.nn import DegreeEncoder and output the addition of them.
Default : ``both``.
>>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> degree_encoder = DegreeEncoder(5, 16) Example
>>> degree_embedding = degree_encoder(g) -------
""" >>> import dgl
>>> from dgl.nn import DegreeEncoder
def __init__(self, max_degree, embedding_dim, direction="both"):
super(DegreeEncoder, self).__init__() >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
self.direction = direction >>> degree_encoder = DegreeEncoder(5, 16)
if direction == "both": >>> degree_embedding = degree_encoder(g)
self.degree_encoder_1 = nn.Embedding( """
max_degree + 1, embedding_dim, padding_idx=0
) def __init__(self, max_degree, embedding_dim, direction="both"):
self.degree_encoder_2 = nn.Embedding( super(DegreeEncoder, self).__init__()
max_degree + 1, embedding_dim, padding_idx=0 self.direction = direction
) if direction == "both":
else: self.degree_encoder_1 = nn.Embedding(
self.degree_encoder = nn.Embedding( max_degree + 1, embedding_dim, padding_idx=0
max_degree + 1, embedding_dim, padding_idx=0 )
) self.degree_encoder_2 = nn.Embedding(
self.max_degree = max_degree max_degree + 1, embedding_dim, padding_idx=0
)
def forward(self, g): else:
""" self.degree_encoder = nn.Embedding(
Parameters max_degree + 1, embedding_dim, padding_idx=0
---------- )
g : DGLGraph self.max_degree = max_degree
A DGLGraph to be encoded. If it is a heterogeneous one,
it will be transformed into a homogeneous one first. def forward(self, g):
"""
Returns Parameters
------- ----------
Tensor g : DGLGraph
Return degree embedding vectors of shape :math:`(N, embedding_dim)`, A DGLGraph to be encoded. If it is a heterogeneous one,
where :math:`N` is th number of nodes in the input graph. it will be transformed into a homogeneous one first.
"""
if len(g.ntypes) > 1 or len(g.etypes) > 1: Returns
g = to_homogeneous(g) -------
in_degree = th.clamp(g.in_degrees(), min=0, max=self.max_degree) Tensor
out_degree = th.clamp(g.out_degrees(), min=0, max=self.max_degree) Return degree embedding vectors of shape :math:`(N, embedding_dim)`,
where :math:`N` is th number of nodes in the input graph.
if self.direction == "in": """
degree_embedding = self.degree_encoder(in_degree) if len(g.ntypes) > 1 or len(g.etypes) > 1:
elif self.direction == "out": g = to_homogeneous(g)
degree_embedding = self.degree_encoder(out_degree) in_degree = th.clamp(g.in_degrees(), min=0, max=self.max_degree)
elif self.direction == "both": out_degree = th.clamp(g.out_degrees(), min=0, max=self.max_degree)
degree_embedding = (self.degree_encoder_1(in_degree)
+ self.degree_encoder_2(out_degree)) if self.direction == "in":
else: degree_embedding = self.degree_encoder(in_degree)
raise ValueError( elif self.direction == "out":
f'Supported direction options: "in", "out" and "both", ' degree_embedding = self.degree_encoder(out_degree)
f'but got {self.direction}' elif self.direction == "both":
) degree_embedding = self.degree_encoder_1(
in_degree
return degree_embedding ) + self.degree_encoder_2(out_degree)
else:
raise ValueError(
class PathEncoder(nn.Module): f'Supported direction options: "in", "out" and "both", '
r"""Path Encoder, as introduced in Edge Encoding of f"but got {self.direction}"
`Do Transformers Really Perform Bad for Graph Representation? )
<https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__ return degree_embedding
This module is a learnable path embedding module and encodes the shortest
path between each pair of nodes as attention bias.
class PathEncoder(nn.Module):
Parameters r"""Path Encoder, as introduced in Edge Encoding of
---------- `Do Transformers Really Perform Bad for Graph Representation?
max_len : int <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__
Maximum number of edges in each path to be encoded. This module is a learnable path embedding module and encodes the shortest
Exceeding part of each path will be truncated, i.e. path between each pair of nodes as attention bias.
truncating edges with serial number no less than :attr:`max_len`.
feat_dim : int Parameters
Dimension of edge features in the input graph. ----------
num_heads : int, optional max_len : int
Number of attention heads if multi-head attention mechanism is applied. Maximum number of edges in each path to be encoded.
Default : 1. Exceeding part of each path will be truncated, i.e.
truncating edges with serial number no less than :attr:`max_len`.
Examples feat_dim : int
-------- Dimension of edge features in the input graph.
>>> import torch as th num_heads : int, optional
>>> import dgl Number of attention heads if multi-head attention mechanism is applied.
Default : 1.
>>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
>>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1]) Examples
>>> g = dgl.graph((u, v)) --------
>>> edata = th.rand(8, 16) >>> import torch as th
>>> path_encoder = dgl.PathEncoder(2, 16, 8) >>> import dgl
>>> out = path_encoder(g, edata) >>> from dgl.nn import PathEncoder
"""
>>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
def __init__(self, max_len, feat_dim, num_heads=1): >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
super().__init__() >>> g = dgl.graph((u, v))
self.max_len = max_len >>> edata = th.rand(8, 16)
self.feat_dim = feat_dim >>> path_encoder = PathEncoder(2, 16, num_heads=8)
self.num_heads = num_heads >>> out = path_encoder(g, edata)
self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim) """
def forward(self, g, edge_feat): def __init__(self, max_len, feat_dim, num_heads=1):
""" super().__init__()
Parameters self.max_len = max_len
---------- self.feat_dim = feat_dim
g : DGLGraph self.num_heads = num_heads
A DGLGraph to be encoded, which must be a homogeneous one. self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)
edge_feat : torch.Tensor
The input edge feature of shape :math:`(E, feat_dim)`, def forward(self, g, edge_feat):
where :math:`E` is the number of edges in the input graph. """
Parameters
Returns ----------
------- g : DGLGraph
torch.Tensor A DGLGraph to be encoded, which must be a homogeneous one.
Return attention bias as path encoding, edge_feat : torch.Tensor
of shape :math:`(batch_size, N, N, num_heads)`, The input edge feature of shape :math:`(E, feat_dim)`,
where :math:`N` is the maximum number of nodes where :math:`E` is the number of edges in the input graph.
and batch_size is the batch size of the input graph.
""" Returns
-------
g_list = unbatch(g) torch.Tensor
sum_num_edges = 0 Return attention bias as path encoding,
max_num_nodes = th.max(g.batch_num_nodes()) of shape :math:`(batch_size, N, N, num_heads)`,
path_encoding = [] where :math:`N` is the maximum number of nodes
and batch_size is the batch size of the input graph.
for ubg in g_list: """
num_nodes = ubg.num_nodes()
num_edges = ubg.num_edges() g_list = unbatch(g)
edata = edge_feat[sum_num_edges: (sum_num_edges + num_edges)] sum_num_edges = 0
sum_num_edges = sum_num_edges + num_edges max_num_nodes = th.max(g.batch_num_nodes())
edata = th.cat( path_encoding = []
(edata, th.zeros(1, self.feat_dim).to(edata.device)),
dim=0 for ubg in g_list:
) num_nodes = ubg.num_nodes()
_, path = shortest_dist(ubg, root=None, return_paths=True) num_edges = ubg.num_edges()
path_len = min(self.max_len, path.size(dim=2)) edata = edge_feat[sum_num_edges : (sum_num_edges + num_edges)]
sum_num_edges = sum_num_edges + num_edges
# shape: [n, n, l], n = num_nodes, l = path_len edata = th.cat(
shortest_path = path[:, :, 0: path_len] (edata, th.zeros(1, self.feat_dim).to(edata.device)), dim=0
# shape: [n, n] )
shortest_distance = th.clamp( dist, path = shortest_dist(ubg, root=None, return_paths=True)
shortest_dist(ubg, root=None, return_paths=False), path_len = max(1, min(self.max_len, path.size(dim=2)))
min=1,
max=path_len # shape: [n, n, l], n = num_nodes, l = path_len
) shortest_path = path[:, :, 0:path_len]
# shape: [n, n, l, d], d = feat_dim # shape: [n, n]
path_data = edata[shortest_path] shortest_distance = th.clamp(dist, min=1, max=path_len)
# shape: [l, h], h = num_heads # shape: [n, n, l, d], d = feat_dim
embedding_idx = th.reshape( path_data = edata[shortest_path]
th.arange(self.num_heads * path_len), # shape: [l, h, d]
(path_len, self.num_heads) edge_embedding = self.embedding_table.weight[
).to(next(self.embedding_table.parameters()).device) 0 : path_len * self.num_heads
# shape: [d, l, h] ].reshape(path_len, self.num_heads, -1)
edge_embedding = th.permute( # [n, n, l, d] einsum [l, h, d] -> [n, n, h]
self.embedding_table(embedding_idx), (2, 0, 1) # [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf
) sub_encoding = th.full(
(max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
# [n, n, l, d] einsum [d, l, h] -> [n, n, h] )
# [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf sub_encoding[0:num_nodes, 0:num_nodes] = th.div(
sub_encoding = th.full( th.einsum("xyld,lhd->xyh", path_data, edge_embedding).permute(
(max_num_nodes, max_num_nodes, self.num_heads), 2, 0, 1
float('-inf') ),
) shortest_distance,
sub_encoding[0: num_nodes, 0: num_nodes] = th.div( ).permute(1, 2, 0)
th.einsum( path_encoding.append(sub_encoding)
'xyld,dlh->xyh', path_data, edge_embedding return th.stack(path_encoding, dim=0)
).permute(2, 0, 1),
shortest_distance
).permute(1, 2, 0) class BiasedMultiheadAttention(nn.Module):
path_encoding.append(sub_encoding) r"""Dense Multi-Head Attention Module with Graph Attention Bias.
return th.stack(path_encoding, dim=0) Compute attention between nodes with attention bias obtained from graph
structures, as introduced in `Do Transformers Really Perform Bad for
Graph Representation? <https://arxiv.org/pdf/2106.05234>`__
class BiasedMultiheadAttention(nn.Module):
r"""Dense Multi-Head Attention Module with Graph Attention Bias. .. math::
Compute attention between nodes with attention bias obtained from graph \text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)
structures, as introduced in `Do Transformers Really Perform Bad for
Graph Representation? <https://arxiv.org/pdf/2106.05234>`__ :math:`Q` and :math:`K` are feature representation of nodes. :math:`d`
is the corresponding :attr:`feat_size`. :math:`b` is attention bias, which
.. math:: can be additive or multiplicative according to the operator :math:`\circ`.
\text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b) Parameters
----------
:math:`Q` and :math:`K` are feature representation of nodes. :math:`d` feat_size : int
is the corresponding :attr:`feat_size`. :math:`b` is attention bias, which Feature size.
can be additive or multiplicative according to the operator :math:`\circ`. num_heads : int
Number of attention heads, by which attr:`feat_size` is divisible.
Parameters bias : bool, optional
---------- If True, it uses bias for linear projection. Default: True.
feat_size : int attn_bias_type : str, optional
Feature size. The type of attention bias used for modifying attention. Selected from
num_heads : int 'add' or 'mul'. Default: 'add'.
Number of attention heads, by which attr:`feat_size` is divisible.
bias : bool, optional * 'add' is for additive attention bias.
If True, it uses bias for linear projection. Default: True. * 'mul' is for multiplicative attention bias.
attn_bias_type : str, optional attn_drop : float, optional
The type of attention bias used for modifying attention. Selected from Dropout probability on attention weights. Defalt: 0.1.
'add' or 'mul'. Default: 'add'.
Examples
* 'add' is for additive attention bias. --------
* 'mul' is for multiplicative attention bias. >>> import torch as th
attn_drop : float, optional >>> from dgl.nn import BiasedMultiheadAttention
Dropout probability on attention weights. Defalt: 0.1.
>>> ndata = th.rand(16, 100, 512)
Examples >>> bias = th.rand(16, 100, 100, 8)
-------- >>> net = BiasedMultiheadAttention(feat_size=512, num_heads=8)
>>> import torch as th >>> out = net(ndata, bias)
>>> from dgl.nn import BiasedMultiheadAttention """
>>> ndata = th.rand(16, 100, 512) def __init__(
>>> bias = th.rand(16, 100, 100, 8) self,
>>> net = BiasedMultiheadAttention(feat_size=512, num_heads=8) feat_size,
>>> out = net(ndata, bias) num_heads,
""" bias=True,
attn_bias_type="add",
def __init__(self, feat_size, num_heads, bias=True, attn_bias_type="add", attn_drop=0.1): attn_drop=0.1,
super().__init__() ):
self.feat_size = feat_size super().__init__()
self.num_heads = num_heads self.feat_size = feat_size
self.head_dim = feat_size // num_heads self.num_heads = num_heads
assert ( self.head_dim = feat_size // num_heads
self.head_dim * num_heads == feat_size assert (
), "feat_size must be divisible by num_heads" self.head_dim * num_heads == feat_size
self.scaling = self.head_dim**-0.5 ), "feat_size must be divisible by num_heads"
self.attn_bias_type = attn_bias_type self.scaling = self.head_dim**-0.5
self.attn_bias_type = attn_bias_type
self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)
self.k_proj = nn.Linear(feat_size, feat_size, bias=bias) self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)
self.v_proj = nn.Linear(feat_size, feat_size, bias=bias) self.k_proj = nn.Linear(feat_size, feat_size, bias=bias)
self.out_proj = nn.Linear(feat_size, feat_size, bias=bias) self.v_proj = nn.Linear(feat_size, feat_size, bias=bias)
self.out_proj = nn.Linear(feat_size, feat_size, bias=bias)
self.dropout = nn.Dropout(p=attn_drop)
self.dropout = nn.Dropout(p=attn_drop)
self.reset_parameters()
self.reset_parameters()
def reset_parameters(self):
"""Reset parameters of projection matrices, the same settings as that in Graphormer. def reset_parameters(self):
""" """Reset parameters of projection matrices, the same settings as that in Graphormer."""
nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5) nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5)
nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5) nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5)
nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5) nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5)
nn.init.xavier_uniform_(self.out_proj.weight) nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None: if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0) nn.init.constant_(self.out_proj.bias, 0.0)
def forward(self, ndata, attn_bias=None, attn_mask=None): def forward(self, ndata, attn_bias=None, attn_mask=None):
"""Forward computation. """Forward computation.
Parameters Parameters
---------- ----------
ndata : torch.Tensor ndata : torch.Tensor
A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
N is the maximum number of nodes. N is the maximum number of nodes.
attn_bias : torch.Tensor, optional attn_bias : torch.Tensor, optional
The attention bias used for attention modification. Shape: The attention bias used for attention modification. Shape:
(batch_size, N, N, :attr:`num_heads`). (batch_size, N, N, :attr:`num_heads`).
attn_mask : torch.Tensor, optional attn_mask : torch.Tensor, optional
The attention mask used for avoiding computation on invalid positions, where The attention mask used for avoiding computation on invalid positions, where
invalid positions are indicated by non-zero values. Shape: (batch_size, N, N). invalid positions are indicated by non-zero values. Shape: (batch_size, N, N).
Returns Returns
------- -------
y : torch.Tensor y : torch.Tensor
The output tensor. Shape: (batch_size, N, :attr:`feat_size`) The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
""" """
q_h = self.q_proj(ndata).transpose(0, 1) q_h = self.q_proj(ndata).transpose(0, 1)
k_h = self.k_proj(ndata).transpose(0, 1) k_h = self.k_proj(ndata).transpose(0, 1)
v_h = self.v_proj(ndata).transpose(0, 1) v_h = self.v_proj(ndata).transpose(0, 1)
bsz, N, _ = ndata.shape bsz, N, _ = ndata.shape
q_h = q_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1) / self.scaling q_h = (
k_h = k_h.reshape(N, bsz * self.num_heads, self.head_dim).permute(1, 2, 0) q_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1)
v_h = v_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1) / self.scaling
)
attn_weights = ( k_h = k_h.reshape(N, bsz * self.num_heads, self.head_dim).permute(
th.bmm(q_h, k_h) 1, 2, 0
.transpose(0, 2) )
.reshape(N, N, bsz, self.num_heads) v_h = v_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(
.transpose(0, 2) 0, 1
) )
if attn_bias is not None: attn_weights = (
if self.attn_bias_type == "add": th.bmm(q_h, k_h)
attn_weights += attn_bias .transpose(0, 2)
else: .reshape(N, N, bsz, self.num_heads)
attn_weights *= attn_bias .transpose(0, 2)
)
if attn_mask is not None:
attn_weights[attn_mask.to(th.bool)] = float("-inf") if attn_bias is not None:
if self.attn_bias_type == "add":
attn_weights = F.softmax( attn_weights += attn_bias
attn_weights.transpose(0, 2) else:
.reshape(N, N, bsz * self.num_heads) attn_weights *= attn_bias
.transpose(0, 2), if attn_mask is not None:
dim=2, attn_weights[attn_mask.to(th.bool)] = float("-inf")
) attn_weights = F.softmax(
attn_weights.transpose(0, 2)
attn_weights = self.dropout(attn_weights) .reshape(N, N, bsz * self.num_heads)
.transpose(0, 2),
attn = th.bmm(attn_weights, v_h).transpose(0, 1) dim=2,
)
attn = self.out_proj(attn.reshape(N, bsz, self.feat_size).transpose(0, 1))
attn_weights = self.dropout(attn_weights)
return attn
attn = th.bmm(attn_weights, v_h).transpose(0, 1)
class GraphormerLayer(nn.Module): attn = self.out_proj(
r"""Graphormer Layer with Dense Multi-Head Attention, as introduced attn.reshape(N, bsz, self.feat_size).transpose(0, 1)
in `Do Transformers Really Perform Bad for Graph Representation? )
<https://arxiv.org/pdf/2106.05234>`__
return attn
Parameters
----------
feat_size : int class GraphormerLayer(nn.Module):
Feature size. r"""Graphormer Layer with Dense Multi-Head Attention, as introduced
hidden_size : int in `Do Transformers Really Perform Bad for Graph Representation?
Hidden size of feedforward layers. <https://arxiv.org/pdf/2106.05234>`__
num_heads : int
Number of attention heads, by which :attr:`feat_size` is divisible. Parameters
attn_bias_type : str, optional ----------
The type of attention bias used for modifying attention. Selected from feat_size : int
'add' or 'mul'. Default: 'add'. Feature size.
hidden_size : int
* 'add' is for additive attention bias. Hidden size of feedforward layers.
* 'mul' is for multiplicative attention bias. num_heads : int
norm_first : bool, optional Number of attention heads, by which :attr:`feat_size` is divisible.
If True, it performs layer normalization before attention and attn_bias_type : str, optional
feedforward operations. Otherwise, it applies layer normalization The type of attention bias used for modifying attention. Selected from
afterwards. Default: False. 'add' or 'mul'. Default: 'add'.
dropout : float, optional
Dropout probability. Default: 0.1. * 'add' is for additive attention bias.
activation : callable activation layer, optional * 'mul' is for multiplicative attention bias.
Activation function. Default: nn.ReLU(). norm_first : bool, optional
If True, it performs layer normalization before attention and
Examples feedforward operations. Otherwise, it applies layer normalization
-------- afterwards. Default: False.
>>> import torch as th dropout : float, optional
>>> from dgl.nn import GraphormerLayer Dropout probability. Default: 0.1.
activation : callable activation layer, optional
>>> batch_size = 16 Activation function. Default: nn.ReLU().
>>> num_nodes = 100
>>> feat_size = 512 Examples
>>> num_heads = 8 --------
>>> nfeat = th.rand(batch_size, num_nodes, feat_size) >>> import torch as th
>>> bias = th.rand(batch_size, num_nodes, num_nodes, num_heads) >>> from dgl.nn import GraphormerLayer
>>> net = GraphormerLayer(
feat_size=feat_size, >>> batch_size = 16
hidden_size=2048, >>> num_nodes = 100
num_heads=num_heads >>> feat_size = 512
) >>> num_heads = 8
>>> out = net(nfeat, bias) >>> nfeat = th.rand(batch_size, num_nodes, feat_size)
""" >>> bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
>>> net = GraphormerLayer(
def __init__( feat_size=feat_size,
self, hidden_size=2048,
feat_size, num_heads=num_heads
hidden_size, )
num_heads, >>> out = net(nfeat, bias)
attn_bias_type='add', """
norm_first=False,
dropout=0.1, def __init__(
activation=nn.ReLU() self,
): feat_size,
super().__init__() hidden_size,
num_heads,
self.norm_first = norm_first attn_bias_type="add",
norm_first=False,
self.attn = BiasedMultiheadAttention( dropout=0.1,
feat_size=feat_size, activation=nn.ReLU(),
num_heads=num_heads, ):
attn_bias_type=attn_bias_type, super().__init__()
attn_drop=dropout
) self.norm_first = norm_first
self.ffn = nn.Sequential(
nn.Linear(feat_size, hidden_size), self.attn = BiasedMultiheadAttention(
activation, feat_size=feat_size,
nn.Dropout(p=dropout), num_heads=num_heads,
nn.Linear(hidden_size, feat_size), attn_bias_type=attn_bias_type,
nn.Dropout(p=dropout) attn_drop=dropout,
) )
self.ffn = nn.Sequential(
self.dropout = nn.Dropout(p=dropout) nn.Linear(feat_size, hidden_size),
self.attn_layer_norm = nn.LayerNorm(feat_size) activation,
self.ffn_layer_norm = nn.LayerNorm(feat_size) nn.Dropout(p=dropout),
nn.Linear(hidden_size, feat_size),
def forward(self, nfeat, attn_bias=None, attn_mask=None): nn.Dropout(p=dropout),
"""Forward computation. )
Parameters self.dropout = nn.Dropout(p=dropout)
---------- self.attn_layer_norm = nn.LayerNorm(feat_size)
nfeat : torch.Tensor self.ffn_layer_norm = nn.LayerNorm(feat_size)
A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
N is the maximum number of nodes. def forward(self, nfeat, attn_bias=None, attn_mask=None):
attn_bias : torch.Tensor, optional """Forward computation.
The attention bias used for attention modification. Shape:
(batch_size, N, N, :attr:`num_heads`). Parameters
attn_mask : torch.Tensor, optional ----------
The attention mask used for avoiding computation on invalid nfeat : torch.Tensor
positions. Shape: (batch_size, N, N). A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
N is the maximum number of nodes.
Returns attn_bias : torch.Tensor, optional
------- The attention bias used for attention modification. Shape:
y : torch.Tensor (batch_size, N, N, :attr:`num_heads`).
The output tensor. Shape: (batch_size, N, :attr:`feat_size`) attn_mask : torch.Tensor, optional
""" The attention mask used for avoiding computation on invalid
residual = nfeat positions. Shape: (batch_size, N, N).
if self.norm_first:
nfeat = self.attn_layer_norm(nfeat) Returns
nfeat = self.attn(nfeat, attn_bias, attn_mask) -------
nfeat = self.dropout(nfeat) y : torch.Tensor
nfeat = residual + nfeat The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
if not self.norm_first: """
nfeat = self.attn_layer_norm(nfeat) residual = nfeat
if self.norm_first:
residual = nfeat nfeat = self.attn_layer_norm(nfeat)
if self.norm_first: nfeat = self.attn(nfeat, attn_bias, attn_mask)
nfeat = self.ffn_layer_norm(nfeat) nfeat = self.dropout(nfeat)
nfeat = self.ffn(nfeat) nfeat = residual + nfeat
nfeat = residual + nfeat if not self.norm_first:
if not self.norm_first: nfeat = self.attn_layer_norm(nfeat)
nfeat = self.ffn_layer_norm(nfeat) residual = nfeat
if self.norm_first:
return nfeat nfeat = self.ffn_layer_norm(nfeat)
nfeat = self.ffn(nfeat)
nfeat = residual + nfeat
if not self.norm_first:
nfeat = self.ffn_layer_norm(nfeat)
return nfeat
class SpatialEncoder(nn.Module):
    r"""Spatial Encoder, as introduced in
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__

    This module is a learnable spatial embedding module which encodes
    the shortest distance between each node pair for attention bias.

    Parameters
    ----------
    max_dist : int
        Upper bound of the shortest path distance
        between each node pair to be encoded.
        All distances will be clamped into the range `[0, max_dist]`.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism is applied.
        Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import SpatialEncoder

    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
    >>> out = spatial_encoder(g)
    >>> print(out.shape)
    torch.Size([1, 4, 4, 8])
    """

    def __init__(self, max_dist, num_heads=1):
        super().__init__()
        self.max_dist = max_dist
        self.num_heads = num_heads
        # deactivate node pair between which the distance is -1
        self.embedding_table = nn.Embedding(
            max_dist + 2, num_heads, padding_idx=0
        )

    def forward(self, g):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded, which must be a homogeneous one.

        Returns
        -------
        torch.Tensor
            Return attention bias as spatial encoding of shape
            :math:`(B, N, N, H)`, where :math:`N` is the maximum number of
            nodes, :math:`B` is the batch size of the input graph, and
            :math:`H` is :attr:`num_heads`.
        """
        device = g.device
        g_list = unbatch(g)
        max_num_nodes = th.max(g.batch_num_nodes())
        spatial_encoding = []

        for ubg in g_list:
            num_nodes = ubg.num_nodes()
            dist = (
                th.clamp(
                    shortest_dist(ubg, root=None, return_paths=False),
                    min=-1,
                    max=self.max_dist,
                )
                + 1
            )
            # shape: [n, n, h], n = num_nodes, h = num_heads
            dist_embedding = self.embedding_table(dist)
            # [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf
            padded_encoding = th.full(
                (max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
            ).to(device)
            padded_encoding[0:num_nodes, 0:num_nodes] = dist_embedding
            spatial_encoding.append(padded_encoding)
        return th.stack(spatial_encoding, dim=0)

class SpatialEncoder3d(nn.Module):
    r"""3D Spatial Encoder, as introduced in
    `One Transformer Can Understand Both 2D & 3D Molecular Data
    <https://arxiv.org/pdf/2210.01765.pdf>`__

    This module encodes pair-wise relation between atom pair :math:`(i,j)` in
    the 3D geometric space, according to the Gaussian Basis Kernel function:

    :math:`\psi _{(i,j)} ^k = -\frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert}
    \exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lvert \lvert r_i -
    r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right)
    ^2 \right)},k=1,...,K,`

    where :math:`K` is the number of Gaussian Basis kernels.
    :math:`r_i` is the Cartesian coordinate of atom :math:`i`.
    :math:`\gamma_{(i,j)}, \beta_{(i,j)}` are learnable scaling factors of
    the Gaussian Basis kernels.

    Parameters
    ----------
    num_kernels : int
        Number of Gaussian Basis Kernels to be applied.
        Each Gaussian Basis Kernel contains a learnable kernel center
        and a learnable scaling factor.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism is applied.
        Default : 1.
    max_node_type : int, optional
        Maximum number of node types. Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import SpatialEncoder3d

    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> coordinate = th.rand(4, 3)
    >>> node_type = th.tensor([1, 0, 2, 1])
    >>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
    ...                                    num_heads=8,
    ...                                    max_node_type=3)
    >>> out = spatial_encoder(g, coordinate, node_type=node_type)
    >>> print(out.shape)
    torch.Size([1, 4, 4, 8])
    """

    def __init__(self, num_kernels, num_heads=1, max_node_type=1):
        super().__init__()
        self.num_kernels = num_kernels
        self.num_heads = num_heads
        self.max_node_type = max_node_type
        self.gaussian_means = nn.Embedding(1, num_kernels)
        self.gaussian_stds = nn.Embedding(1, num_kernels)
        self.linear_layer_1 = nn.Linear(num_kernels, num_kernels)
        self.linear_layer_2 = nn.Linear(num_kernels, num_heads)
        if max_node_type == 1:
            self.mul = nn.Embedding(1, 1)
            self.bias = nn.Embedding(1, 1)
        else:
            self.mul = nn.Embedding(max_node_type + 1, 2)
            self.bias = nn.Embedding(max_node_type + 1, 2)

        nn.init.uniform_(self.gaussian_means.weight, 0, 3)
        nn.init.uniform_(self.gaussian_stds.weight, 0, 3)
        nn.init.constant_(self.mul.weight, 0)
        nn.init.constant_(self.bias.weight, 1)

    def forward(self, g, coord, node_type=None):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded, which must be a homogeneous one.
        coord : torch.Tensor
            3D coordinates of nodes in :attr:`g`,
            of shape :math:`(N, 3)`,
            where :math:`N` is the number of nodes in :attr:`g`.
        node_type : torch.Tensor, optional
            Node types of :attr:`g`. Default : None.

            * If :attr:`max_node_type` is not 1, :attr:`node_type` needs to
              be a tensor in shape :math:`(N,)`. The scaling factors of
              each pair of nodes are determined by their node types.
            * Otherwise, :attr:`node_type` should be None.

        Returns
        -------
        torch.Tensor
            Return attention bias as 3D spatial encoding of shape
            :math:`(B, n, n, H)`, where :math:`B` is the batch size, :math:`n`
            is the maximum number of nodes in unbatched graphs from :attr:`g`,
            and :math:`H` is :attr:`num_heads`.
        """
        device = g.device
        g_list = unbatch(g)
        max_num_nodes = th.max(g.batch_num_nodes())
        spatial_encoding = []
        sum_num_nodes = 0
        if (self.max_node_type == 1) != (node_type is None):
            raise ValueError(
                "input node_type should be None if and only if "
                "max_node_type is 1."
            )

        for ubg in g_list:
            num_nodes = ubg.num_nodes()
            sub_coord = coord[sum_num_nodes : sum_num_nodes + num_nodes]
            # shape: [n, n], n = num_nodes
            euc_dist = th.cdist(sub_coord, sub_coord, p=2)
            if node_type is None:
                # shape: [1]
                mul = self.mul.weight[0, 0]
                bias = self.bias.weight[0, 0]
            else:
                sub_node_type = node_type[
                    sum_num_nodes : sum_num_nodes + num_nodes
                ]
                mul_embedding = self.mul(sub_node_type)
                bias_embedding = self.bias(sub_node_type)
                # shape: [n, n]
                mul = mul_embedding[:, 0].unsqueeze(-1).repeat(
                    1, num_nodes
                ) + mul_embedding[:, 1].unsqueeze(0).repeat(num_nodes, 1)
                bias = bias_embedding[:, 0].unsqueeze(-1).repeat(
                    1, num_nodes
                ) + bias_embedding[:, 1].unsqueeze(0).repeat(num_nodes, 1)
            # shape: [n, n, k], k = num_kernels
            scaled_dist = (
                (mul * euc_dist + bias)
                .repeat(self.num_kernels, 1, 1)
                .permute((1, 2, 0))
            )
            # shape: [k]
            gaussian_mean = self.gaussian_means.weight.float().view(-1)
            gaussian_var = (
                self.gaussian_stds.weight.float().view(-1).abs() + 1e-2
            )
            # shape: [n, n, k]
            gaussian_kernel = (
                (
                    -0.5
                    * (
                        th.div(
                            scaled_dist - gaussian_mean, gaussian_var
                        ).square()
                    )
                )
                .exp()
                .div(-math.sqrt(2 * math.pi) * gaussian_var)
            )
            encoding = self.linear_layer_1(gaussian_kernel)
            encoding = F.gelu(encoding)
            # [n, n, k] -> [n, n, a], a = num_heads
            encoding = self.linear_layer_2(encoding)
            # [n, n, a] -> [N, N, a], N = max_num_nodes, padded with -inf
            padded_encoding = th.full(
                (max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
            ).to(device)
            padded_encoding[0:num_nodes, 0:num_nodes] = encoding
            spatial_encoding.append(padded_encoding)
            sum_num_nodes += num_nodes
        return th.stack(spatial_encoding, dim=0)

...
@@ -1844,3 +1844,32 @@ def test_PathEncoder(max_len, feat_dim, num_heads):
    model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
    bias = model(bg, edge_feat)
    assert bias.shape == (2, 6, 6, num_heads)

@pytest.mark.parametrize('max_dist', [1, 4])
@pytest.mark.parametrize('num_kernels', [8, 16])
@pytest.mark.parametrize('num_heads', [1, 8])
def test_SpatialEncoder(max_dist, num_kernels, num_heads):
    dev = F.ctx()
    g1 = dgl.graph((
        th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
        th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    )).to(dev)
    g2 = dgl.graph((
        th.tensor([0, 1, 2, 3, 2, 5]),
        th.tensor([1, 2, 3, 4, 0, 3])
    )).to(dev)
    bg = dgl.batch([g1, g2])
    ndata = th.rand(bg.num_nodes(), 3).to(dev)
    num_nodes = bg.num_nodes()
    node_type = th.randint(0, 512, (num_nodes,)).to(dev)
    model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
    model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
    model_3 = nn.SpatialEncoder3d(
        num_kernels, num_heads=num_heads, max_node_type=512
    ).to(dev)
    encoding = model_1(bg)
    encoding3d_1 = model_2(bg, ndata)
    encoding3d_2 = model_3(bg, ndata, node_type)
    assert encoding.shape == (2, 6, 6, num_heads)
    assert encoding3d_1.shape == (2, 6, 6, num_heads)
    assert encoding3d_2.shape == (2, 6, 6, num_heads)
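
As a sanity check on the Gaussian Basis Kernel formula in the SpatialEncoder3d docstring above, the following standalone sketch (not part of the commit) reproduces the per-kernel value for a single node pair with toy numbers; here `gamma`/`beta` correspond to the module's `mul`/`bias` parameters at their initial values:

import math
import torch as th

# Toy setup: one node pair (i, j) and K = 2 Gaussian kernels.
dist = 1.5                      # Euclidean distance ||r_i - r_j||
gamma, beta = 0.0, 1.0          # scaling factors; SpatialEncoder3d initializes mul=0, bias=1
mu = th.tensor([0.5, 2.0])      # kernel centers mu^k
sigma = th.tensor([1.0, 0.5])   # kernel widths |sigma^k| (the module adds 1e-2 for stability)

x = gamma * dist + beta         # the scaled distance shared by all kernels
psi = (-0.5 * ((x - mu) / sigma) ** 2).exp() / (-math.sqrt(2 * math.pi) * sigma)
# psi has shape [K] and corresponds to the `gaussian_kernel` tensor computed in
# SpatialEncoder3d.forward, before the two linear layers and the GELU.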