Unverified Commit 19c7e5af authored by Zhiteng Li, committed by GitHub

[NN] Add EGT Layer (#6056)


Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent ffd8edeb
@@ -135,6 +135,7 @@ Utility Modules
~dgl.nn.pytorch.graph_transformer.DegreeEncoder
~dgl.nn.pytorch.utils.LaplacianPosEnc
~dgl.nn.pytorch.graph_transformer.BiasedMultiheadAttention
~dgl.nn.pytorch.graph_transformer.EGTLayer
~dgl.nn.pytorch.graph_transformer.GraphormerLayer
~dgl.nn.pytorch.graph_transformer.PathEncoder
~dgl.nn.pytorch.graph_transformer.SpatialEncoder
@@ -165,3 +166,4 @@ Utility Modules for Graph Transformer
~dgl.nn.pytorch.gt.SpatialEncoder3d
~dgl.nn.pytorch.gt.BiasedMHA
~dgl.nn.pytorch.gt.GraphormerLayer
~dgl.nn.pytorch.gt.EGTLayer
@@ -2,6 +2,7 @@
from .biased_mha import BiasedMHA
from .degree_encoder import DegreeEncoder
from .egt import EGTLayer
from .graphormer import GraphormerLayer
from .lap_pos_encoder import LapPosEncoder
from .path_encoder import PathEncoder
...
"""EGT Layer"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class EGTLayer(nn.Module):
r"""EGTLayer for Edge-augmented Graph Transformer (EGT), as introduced in
`Global Self-Attention as a Replacement for Graph Convolution
Reference `<https://arxiv.org/pdf/2108.03348.pdf>`_
Parameters
----------
feat_size : int
Node feature size.
edge_feat_size : int
Edge feature size.
num_heads : int
Number of attention heads, by which :attr:`feat_size` is divisible.
num_virtual_nodes : int
Number of virtual nodes.
dropout : float, optional
Dropout probability. Default: 0.0.
attn_dropout : float, optional
Attention dropout probability. Default: 0.0.
activation : callable activation layer, optional
Activation function. Default: nn.ELU().
edge_update : bool, optional
Whether to update the edge embedding. Default: True.
Examples
--------
>>> import torch as th
>>> from dgl.nn import EGTLayer
>>> batch_size = 16
>>> num_nodes = 100
>>> feat_size, edge_feat_size = 128, 32
>>> nfeat = th.rand(batch_size, num_nodes, feat_size)
>>> efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
>>> net = EGTLayer(
feat_size=feat_size,
edge_feat_size=edge_feat_size,
num_heads=8,
num_virtual_nodes=4,
)
>>> out = net(nfeat, efeat)
"""
def __init__(
self,
feat_size,
edge_feat_size,
num_heads,
num_virtual_nodes,
dropout=0,
attn_dropout=0,
activation=nn.ELU(),
edge_update=True,
):
super().__init__()
self.num_heads = num_heads
self.num_virtual_nodes = num_virtual_nodes
self.edge_update = edge_update
assert (
feat_size % num_heads == 0
), "feat_size must be divisible by num_heads"
self.dot_dim = feat_size // num_heads
self.mha_ln_h = nn.LayerNorm(feat_size)
self.mha_ln_e = nn.LayerNorm(edge_feat_size)
self.edge_input = nn.Linear(edge_feat_size, num_heads)
self.qkv_proj = nn.Linear(feat_size, feat_size * 3)
self.gate = nn.Linear(edge_feat_size, num_heads)
self.attn_dropout = nn.Dropout(attn_dropout)
self.node_output = nn.Linear(feat_size, feat_size)
self.mha_dropout_h = nn.Dropout(dropout)
self.node_ffn = nn.Sequential(
nn.LayerNorm(feat_size),
nn.Linear(feat_size, feat_size),
activation,
nn.Linear(feat_size, feat_size),
nn.Dropout(dropout),
)
if self.edge_update:
self.edge_output = nn.Linear(num_heads, edge_feat_size)
self.mha_dropout_e = nn.Dropout(dropout)
self.edge_ffn = nn.Sequential(
nn.LayerNorm(edge_feat_size),
nn.Linear(edge_feat_size, edge_feat_size),
activation,
nn.Linear(edge_feat_size, edge_feat_size),
nn.Dropout(dropout),
)
def forward(self, nfeat, efeat, mask=None):
"""Forward computation. Note: :attr:`nfeat` and :attr:`efeat` should be
padded with embedding of virtual nodes if :attr:`num_virtual_nodes` > 0,
while :attr:`mask` should be padded with `0` values for virtual nodes.
The padding should be put at the beginning.
Parameters
----------
nfeat : torch.Tensor
A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where N
is the sum of the maximum number of nodes and the number of virtual nodes.
efeat : torch.Tensor
Edge embedding used for attention computation and self update.
Shape: (batch_size, N, N, :attr:`edge_feat_size`).
mask : torch.Tensor, optional
The attention mask used for avoiding computation on invalid
positions, where valid positions are indicated by `0` and
invalid positions are indicated by `-inf`.
Shape: (batch_size, N, N). Default: None.
Returns
-------
nfeat : torch.Tensor
The output node embedding. Shape: (batch_size, N, :attr:`feat_size`).
efeat : torch.Tensor, optional
The output edge embedding. Shape: (batch_size, N, N, :attr:`edge_feat_size`).
It is returned only if :attr:`edge_update` is True.
"""
nfeat_r1 = nfeat
efeat_r1 = efeat
nfeat_ln = self.mha_ln_h(nfeat)
efeat_ln = self.mha_ln_e(efeat)
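# Project normalized node features to queries/keys/values; map normalized edge features to per-head attention biases and gate logits.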
qkv = self.qkv_proj(nfeat_ln)
e_bias = self.edge_input(efeat_ln)
gates = self.gate(efeat_ln)
bsz, N, _ = qkv.shape
q_h, k_h, v_h = qkv.view(bsz, N, -1, self.num_heads).split(
self.dot_dim, dim=2
)
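# Per-head dot-product attention logits, clamped to [-5, 5] for stability and shifted by the edge bias.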
attn_hat = torch.einsum("bldh,bmdh->blmh", q_h, k_h)
attn_hat = attn_hat.clamp(-5, 5) + e_bias
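# Edge-derived sigmoid gates modulate the softmax attention; the additive mask (0 = valid, -inf = invalid) suppresses padded positions.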
if mask is None:
gates = torch.sigmoid(gates)
attn_tild = F.softmax(attn_hat, dim=2) * gates
else:
gates = torch.sigmoid(gates + mask.unsqueeze(-1))
attn_tild = F.softmax(attn_hat + mask.unsqueeze(-1), dim=2) * gates
attn_tild = self.attn_dropout(attn_tild)
v_attn = torch.einsum("blmh,bmkh->blkh", attn_tild, v_h)
# Scale the aggregated values by degree.
degrees = torch.sum(gates, dim=2, keepdim=True)
degree_scalers = torch.log(1 + degrees)
degree_scalers[:, : self.num_virtual_nodes] = 1.0
v_attn = v_attn * degree_scalers
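# Merge attention heads, project back to feat_size, then apply dropout, residual connection, and the node FFN.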
v_attn = v_attn.reshape(bsz, N, self.num_heads * self.dot_dim)
nfeat = self.node_output(v_attn)
nfeat = self.mha_dropout_h(nfeat)
nfeat.add_(nfeat_r1)
nfeat_r2 = nfeat
nfeat = self.node_ffn(nfeat)
nfeat.add_(nfeat_r2)
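# Optionally update edge embeddings from the biased attention logits, with dropout, residual connection, and the edge FFN.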
if self.edge_update:
efeat = self.edge_output(attn_hat)
efeat = self.mha_dropout_e(efeat)
efeat.add_(efeat_r1)
efeat_r2 = efeat
efeat = self.edge_ffn(efeat)
efeat.add_(efeat_r2)
return nfeat, efeat
return nfeat
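For reference, a minimal usage sketch (not part of this commit) of how inputs could be prepared when virtual nodes are used: virtual-node embeddings are prepended to the node and edge features, and the attention mask holds `0` at valid positions and a large negative value (standing in for `-inf`) at padded ones, as described in the forward docstring. The padding logic, the zero virtual-node embeddings, and the per-graph node counts below are illustrative assumptions, not DGL API.

import torch
from dgl.nn import EGTLayer

batch_size, max_nodes, num_virtual_nodes = 4, 10, 2
feat_size, edge_feat_size = 128, 32
N = num_virtual_nodes + max_nodes

# Real-node features; virtual-node embeddings (zeros here purely for
# illustration, typically learned) are prepended at the beginning.
nfeat = torch.rand(batch_size, max_nodes, feat_size)
nfeat = torch.cat(
    [torch.zeros(batch_size, num_virtual_nodes, feat_size), nfeat], dim=1
)
efeat = torch.zeros(batch_size, N, N, edge_feat_size)
efeat[:, num_virtual_nodes:, num_virtual_nodes:] = torch.rand(
    batch_size, max_nodes, max_nodes, edge_feat_size
)

# Mask: 0 for valid positions (including virtual nodes), a large negative
# value for rows/columns of padded nodes in graphs smaller than max_nodes.
num_real_nodes = torch.tensor([10, 8, 9, 7])
mask = torch.zeros(batch_size, N, N)
for b, n in enumerate(num_real_nodes.tolist()):
    mask[b, num_virtual_nodes + n :, :] = -1e9
    mask[b, :, num_virtual_nodes + n :] = -1e9

net = EGTLayer(
    feat_size=feat_size,
    edge_feat_size=edge_feat_size,
    num_heads=8,
    num_virtual_nodes=num_virtual_nodes,
)
out_nfeat, out_efeat = net(nfeat, efeat, mask)  # edge_update=True by default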
@@ -2524,6 +2524,37 @@ def test_BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop):
assert out.shape == (16, 100, feat_size)
@pytest.mark.parametrize("edge_update", [True, False])
def test_EGTLayer(edge_update):
batch_size = 16
num_nodes = 100
feat_size, edge_feat_size = 128, 32
nfeat = th.rand(batch_size, num_nodes, feat_size)
efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
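# Mask roughly half of the positions with a large negative value (acts as -inf in the masked softmax).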
mask = (th.rand(batch_size, num_nodes, num_nodes) < 0.5) * -1e9
net = nn.EGTLayer(
feat_size=feat_size,
edge_feat_size=edge_feat_size,
num_heads=8,
num_virtual_nodes=4,
edge_update=edge_update,
)
if edge_update:
out_nfeat, out_efeat = net(nfeat, efeat, mask)
assert out_nfeat.shape == (batch_size, num_nodes, feat_size)
assert out_efeat.shape == (
batch_size,
num_nodes,
num_nodes,
edge_feat_size,
)
else:
out_nfeat = net(nfeat, efeat, mask)
assert out_nfeat.shape == (batch_size, num_nodes, feat_size)
@pytest.mark.parametrize("attn_bias_type", ["add", "mul"]) @pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("norm_first", [True, False]) @pytest.mark.parametrize("norm_first", [True, False])
def test_GraphormerLayer(attn_bias_type, norm_first): def test_GraphormerLayer(attn_bias_type, norm_first):
......