OpenDAS / dgl / Commits

Unverified commit 19c7e5af, authored Aug 02, 2023 by Zhiteng Li, committed by GitHub on Aug 02, 2023.

[NN] Add EGT Layer (#6056)

Co-authored-by: rudongyu <ru_dongyu@outlook.com>

Parent: ffd8edeb
Showing 4 changed files with 211 additions and 0 deletions (+211 −0):

- docs/source/api/python/nn-pytorch.rst (+2 −0)
- python/dgl/nn/pytorch/gt/__init__.py (+1 −0)
- python/dgl/nn/pytorch/gt/egt.py (+177 −0)
- tests/python/pytorch/nn/test_nn.py (+31 −0)
docs/source/api/python/nn-pytorch.rst

@@ -135,6 +135,7 @@ Utility Modules
     ~dgl.nn.pytorch.graph_transformer.DegreeEncoder
     ~dgl.nn.pytorch.utils.LaplacianPosEnc
     ~dgl.nn.pytorch.graph_transformer.BiasedMultiheadAttention
+    ~dgl.nn.pytorch.graph_transformer.EGTLayer
     ~dgl.nn.pytorch.graph_transformer.GraphormerLayer
     ~dgl.nn.pytorch.graph_transformer.PathEncoder
     ~dgl.nn.pytorch.graph_transformer.SpatialEncoder

@@ -165,3 +166,4 @@ Utility Modules for Graph Transformer
     ~dgl.nn.pytorch.gt.SpatialEncoder3d
     ~dgl.nn.pytorch.gt.BiasedMHA
     ~dgl.nn.pytorch.gt.GraphormerLayer
+    ~dgl.nn.pytorch.gt.EGTLayer
python/dgl/nn/pytorch/gt/__init__.py

@@ -2,6 +2,7 @@
 from .biased_mha import BiasedMHA
 from .degree_encoder import DegreeEncoder
+from .egt import EGTLayer
 from .graphormer import GraphormerLayer
 from .lap_pos_encoder import LapPosEncoder
 from .path_encoder import PathEncoder
 ...
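With this re-export in place, the layer is reachable both through the graph-transformer subpackage (the path used in the docs entry above) and through the top-level dgl.nn namespace (the path used in the docstring example below). A minimal sanity check, assuming DGL's usual backend re-export chain exposes the same class object at both paths:

    # Minimal import check (illustrative; assumes DGL's standard re-export chain).
    from dgl.nn import EGTLayer                            # path used in the docstring example
    from dgl.nn.pytorch.gt import EGTLayer as EGTLayerGT   # path used in the docs entry

    print(EGTLayer is EGTLayerGT)  # expected True if both names resolve to the same class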
python/dgl/nn/pytorch/gt/egt.py (new file, mode 100644)

"""EGT Layer"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class EGTLayer(nn.Module):
    r"""EGTLayer for Edge-augmented Graph Transformer (EGT), as introduced in
    `Global Self-Attention as a Replacement for Graph Convolution
    <https://arxiv.org/pdf/2108.03348.pdf>`__

    Parameters
    ----------
    feat_size : int
        Node feature size.
    edge_feat_size : int
        Edge feature size.
    num_heads : int
        Number of attention heads, by which :attr:`feat_size` is divisible.
    num_virtual_nodes : int
        Number of virtual nodes.
    dropout : float, optional
        Dropout probability. Default: 0.0.
    attn_dropout : float, optional
        Attention dropout probability. Default: 0.0.
    activation : callable activation layer, optional
        Activation function. Default: nn.ELU().
    edge_update : bool, optional
        Whether to update the edge embedding. Default: True.

    Examples
    --------
    >>> import torch as th
    >>> from dgl.nn import EGTLayer

    >>> batch_size = 16
    >>> num_nodes = 100
    >>> feat_size, edge_feat_size = 128, 32
    >>> nfeat = th.rand(batch_size, num_nodes, feat_size)
    >>> efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
    >>> net = EGTLayer(
    ...     feat_size=feat_size,
    ...     edge_feat_size=edge_feat_size,
    ...     num_heads=8,
    ...     num_virtual_nodes=4,
    ... )
    >>> out = net(nfeat, efeat)
    """

    def __init__(
        self,
        feat_size,
        edge_feat_size,
        num_heads,
        num_virtual_nodes,
        dropout=0,
        attn_dropout=0,
        activation=nn.ELU(),
        edge_update=True,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.num_virtual_nodes = num_virtual_nodes
        self.edge_update = edge_update

        assert (
            feat_size % num_heads == 0
        ), "feat_size must be divisible by num_heads"
        self.dot_dim = feat_size // num_heads
        self.mha_ln_h = nn.LayerNorm(feat_size)
        self.mha_ln_e = nn.LayerNorm(edge_feat_size)
        self.edge_input = nn.Linear(edge_feat_size, num_heads)
        self.qkv_proj = nn.Linear(feat_size, feat_size * 3)
        self.gate = nn.Linear(edge_feat_size, num_heads)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.node_output = nn.Linear(feat_size, feat_size)
        self.mha_dropout_h = nn.Dropout(dropout)
        self.node_ffn = nn.Sequential(
            nn.LayerNorm(feat_size),
            nn.Linear(feat_size, feat_size),
            activation,
            nn.Linear(feat_size, feat_size),
            nn.Dropout(dropout),
        )

        if self.edge_update:
            self.edge_output = nn.Linear(num_heads, edge_feat_size)
            self.mha_dropout_e = nn.Dropout(dropout)
            self.edge_ffn = nn.Sequential(
                nn.LayerNorm(edge_feat_size),
                nn.Linear(edge_feat_size, edge_feat_size),
                activation,
                nn.Linear(edge_feat_size, edge_feat_size),
                nn.Dropout(dropout),
            )

    def forward(self, nfeat, efeat, mask=None):
        """Forward computation. Note: :attr:`nfeat` and :attr:`efeat` should be
        padded with embedding of virtual nodes if :attr:`num_virtual_nodes` > 0,
        while :attr:`mask` should be padded with `0` values for virtual nodes.
        The padding should be put at the beginning.

        Parameters
        ----------
        nfeat : torch.Tensor
            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where N
            is the sum of the maximum number of nodes and the number of virtual nodes.
        efeat : torch.Tensor
            Edge embedding used for attention computation and self update.
            Shape: (batch_size, N, N, :attr:`edge_feat_size`).
        mask : torch.Tensor, optional
            The attention mask used for avoiding computation on invalid
            positions, where valid positions are indicated by `0` and
            invalid positions are indicated by `-inf`.
            Shape: (batch_size, N, N). Default: None.

        Returns
        -------
        nfeat : torch.Tensor
            The output node embedding. Shape: (batch_size, N, :attr:`feat_size`).
        efeat : torch.Tensor, optional
            The output edge embedding. Shape: (batch_size, N, N, :attr:`edge_feat_size`).
            It is returned only if :attr:`edge_update` is True.
        """
        nfeat_r1 = nfeat
        efeat_r1 = efeat

        nfeat_ln = self.mha_ln_h(nfeat)
        efeat_ln = self.mha_ln_e(efeat)
        qkv = self.qkv_proj(nfeat_ln)
        e_bias = self.edge_input(efeat_ln)
        gates = self.gate(efeat_ln)
        bsz, N, _ = qkv.shape
        q_h, k_h, v_h = qkv.view(bsz, N, -1, self.num_heads).split(
            self.dot_dim, dim=2
        )
        attn_hat = torch.einsum("bldh,bmdh->blmh", q_h, k_h)
        attn_hat = attn_hat.clamp(-5, 5) + e_bias

        if mask is None:
            gates = torch.sigmoid(gates)
            attn_tild = F.softmax(attn_hat, dim=2) * gates
        else:
            gates = torch.sigmoid(gates + mask.unsqueeze(-1))
            attn_tild = F.softmax(attn_hat + mask.unsqueeze(-1), dim=2) * gates

        attn_tild = self.attn_dropout(attn_tild)
        v_attn = torch.einsum("blmh,bmkh->blkh", attn_tild, v_h)

        # Scale the aggregated values by degree.
        degrees = torch.sum(gates, dim=2, keepdim=True)
        degree_scalers = torch.log(1 + degrees)
        degree_scalers[:, : self.num_virtual_nodes] = 1.0
        v_attn = v_attn * degree_scalers

        v_attn = v_attn.reshape(bsz, N, self.num_heads * self.dot_dim)
        nfeat = self.node_output(v_attn)

        nfeat = self.mha_dropout_h(nfeat)
        nfeat.add_(nfeat_r1)
        nfeat_r2 = nfeat
        nfeat = self.node_ffn(nfeat)
        nfeat.add_(nfeat_r2)

        if self.edge_update:
            efeat = self.edge_output(attn_hat)
            efeat = self.mha_dropout_e(efeat)
            efeat.add_(efeat_r1)
            efeat_r2 = efeat
            efeat = self.edge_ffn(efeat)
            efeat.add_(efeat_r2)
            return nfeat, efeat

        return nfeat
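The forward docstring above requires callers to prepend virtual-node embeddings to nfeat and efeat and to pass an additive mask that is 0 at valid positions and -inf at padded ones. The sketch below is not part of the commit; it is a minimal illustration of that contract, and the helper name build_egt_inputs, the zero-initialized virtual-node embeddings, and the -1e9 stand-in for -inf (the value also used in the test below) are assumptions.

    # Illustrative sketch (not part of this commit): prepend virtual nodes and
    # build an additive attention mask for EGTLayer. The helper name, the
    # zero-initialized virtual-node embeddings, and the -1e9 stand-in for -inf
    # are assumptions; real models typically learn the virtual-node embeddings.
    import torch as th
    from dgl.nn import EGTLayer


    def build_egt_inputs(nfeat, efeat, num_real_nodes, num_virtual_nodes):
        """Pad node/edge features with virtual nodes at the front and build a mask.

        nfeat: (batch_size, max_nodes, feat_size)
        efeat: (batch_size, max_nodes, max_nodes, edge_feat_size)
        num_real_nodes: (batch_size,) number of valid (non-padding) nodes per graph
        """
        bsz, max_nodes, feat_size = nfeat.shape
        edge_feat_size = efeat.shape[-1]
        N = max_nodes + num_virtual_nodes

        # Virtual-node padding goes at the beginning, as the docstring requires.
        nfeat_pad = th.zeros(bsz, N, feat_size)
        nfeat_pad[:, num_virtual_nodes:] = nfeat
        efeat_pad = th.zeros(bsz, N, N, edge_feat_size)
        efeat_pad[:, num_virtual_nodes:, num_virtual_nodes:] = efeat

        # Valid positions: all virtual nodes plus each graph's real nodes.
        valid = th.zeros(bsz, N, dtype=th.bool)
        valid[:, :num_virtual_nodes] = True
        for i in range(bsz):
            valid[i, num_virtual_nodes : num_virtual_nodes + int(num_real_nodes[i])] = True

        # Additive mask: 0 where both endpoints are valid, -1e9 elsewhere
        # (a large negative value keeps fully padded rows free of NaNs in softmax).
        pair_valid = valid.unsqueeze(1) & valid.unsqueeze(2)
        mask = th.zeros(bsz, N, N)
        mask.masked_fill_(~pair_valid, -1e9)
        return nfeat_pad, efeat_pad, mask


    # Usage with small shapes; graph 1 has only 3 real nodes out of 5 padded slots.
    bsz, max_nodes, feat_size, edge_feat_size, num_virtual = 2, 5, 128, 32, 4
    nfeat = th.rand(bsz, max_nodes, feat_size)
    efeat = th.rand(bsz, max_nodes, max_nodes, edge_feat_size)
    num_real_nodes = th.tensor([5, 3])
    nfeat_p, efeat_p, mask = build_egt_inputs(nfeat, efeat, num_real_nodes, num_virtual)

    net = EGTLayer(feat_size, edge_feat_size, num_heads=8, num_virtual_nodes=num_virtual)
    out_nfeat, out_efeat = net(nfeat_p, efeat_p, mask)  # edge_update=True by default
    print(out_nfeat.shape, out_efeat.shape)  # (2, 9, 128) and (2, 9, 9, 32)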
tests/python/pytorch/nn/test_nn.py

@@ -2524,6 +2524,37 @@ def test_BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop):
     assert out.shape == (16, 100, feat_size)

+@pytest.mark.parametrize("edge_update", [True, False])
+def test_EGTLayer(edge_update):
+    batch_size = 16
+    num_nodes = 100
+    feat_size, edge_feat_size = 128, 32
+    nfeat = th.rand(batch_size, num_nodes, feat_size)
+    efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
+    mask = (th.rand(batch_size, num_nodes, num_nodes) < 0.5) * -1e9
+    net = nn.EGTLayer(
+        feat_size=feat_size,
+        edge_feat_size=edge_feat_size,
+        num_heads=8,
+        num_virtual_nodes=4,
+        edge_update=edge_update,
+    )
+    if edge_update:
+        out_nfeat, out_efeat = net(nfeat, efeat, mask)
+        assert out_nfeat.shape == (batch_size, num_nodes, feat_size)
+        assert out_efeat.shape == (
+            batch_size,
+            num_nodes,
+            num_nodes,
+            edge_feat_size,
+        )
+    else:
+        out_nfeat = net(nfeat, efeat, mask)
+        assert out_nfeat.shape == (batch_size, num_nodes, feat_size)
+
 @pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
 @pytest.mark.parametrize("norm_first", [True, False])
 def test_GraphormerLayer(attn_bias_type, norm_first):
 ...