Commit bb1f8850 (Unverified)
Authored Apr 14, 2023 by rudongyu; committed by GitHub on Apr 14, 2023
[NN] Refactor the Code Structure of GT (#5100)
Parent: 4085ec8a

Showing 16 changed files with 975 additions and 220 deletions (+975, -220)
docs/source/api/python/dgl.rst                       +1    -1
docs/source/api/python/nn-pytorch.rst                +15   -0
docs/source/api/python/transforms.rst                +1    -1
python/dgl/nn/pytorch/__init__.py                    +2    -8
python/dgl/nn/pytorch/gt/__init__.py                 +8    -0
python/dgl/nn/pytorch/gt/biased_mha.py               +158  -0
python/dgl/nn/pytorch/gt/degree_encoder.py           +92   -0
python/dgl/nn/pytorch/gt/graphormer.py               +125  -0
python/dgl/nn/pytorch/gt/lap_pos_encoder.py          +162  -0
python/dgl/nn/pytorch/gt/path_encoder.py             +105  -0
python/dgl/nn/pytorch/gt/spatial_encoder.py          +254  -0
python/dgl/nn/pytorch/utils.py                       +0    -152
python/dgl/transforms/functional.py                  +22   -15
python/dgl/transforms/module.py                      +21   -14
tests/python/common/transforms/test_transform.py     +4    -4
tests/python/pytorch/nn/test_nn.py                   +5    -25
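Taken together, the diffs below are a set of one-for-one renames plus a package split: the graph-transformer modules move from a single graph_transformer.py into a gt/ package, and only the two dgl-level Laplacian-PE names keep deprecation aliases. A minimal migration sketch, assuming the public re-exports under dgl and dgl.nn shown in the touched __init__ and .rst files:

# Old names (BiasedMultiheadAttention, LaplacianPosEnc, dgl.laplacian_pe,
# dgl.LaplacianPE) map directly onto the new ones; the dgl.nn-level old
# names are removed outright in this commit.
import dgl
from dgl.nn import BiasedMHA, LapPosEncoder  # were BiasedMultiheadAttention, LaplacianPosEnc

g = dgl.graph(([0, 1, 1, 2, 2, 0], [1, 0, 2, 1, 0, 2]))  # bidirected triangle
pe = dgl.lap_pe(g, k=2)        # was dgl.laplacian_pe (old name still works, with a warning)
transform = dgl.LapPE(k=2)     # was dgl.LaplacianPE (old name still works, with a warning)
g = transform(g)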
docs/source/api/python/dgl.rst

@@ -111,7 +111,7 @@ Operators for generating positional encodings of each node.
     :toctree: ../../generated

     random_walk_pe
-    laplacian_pe
+    lap_pe
     double_radius_node_labeling
     shortest_dist
     svd_pe
docs/source/api/python/nn-pytorch.rst

@@ -146,3 +146,18 @@ Network Embedding Modules
     ~dgl.nn.pytorch.DeepWalk
     ~dgl.nn.pytorch.MetaPath2Vec
+
+Utility Modules for Graph Transformer
+----------------------------------------
+
+.. autosummary::
+    :toctree: ../../generated/
+    :nosignatures:
+    :template: classtemplate.rst
+
+    ~dgl.nn.pytorch.gt.DegreeEncoder
+    ~dgl.nn.pytorch.gt.LapPosEncoder
+    ~dgl.nn.pytorch.gt.PathEncoder
+    ~dgl.nn.pytorch.gt.SpatialEncoder
+    ~dgl.nn.pytorch.gt.SpatialEncoder3d
+    ~dgl.nn.pytorch.gt.BiasedMHA
+    ~dgl.nn.pytorch.gt.GraphormerLayer
docs/source/api/python/transforms.rst

@@ -29,7 +29,7 @@ dgl.transforms
     DropEdge
     AddEdge
     RandomWalkPE
-    LaplacianPE
+    LapPE
     FeatMask
     RowFeatNormalizer
     SIGNDiffusion
python/dgl/nn/pytorch/__init__.py

@@ -8,12 +8,6 @@ from .softmax import *
 from .factory import *
 from .hetero import *
 from .sparse_emb import NodeEmbedding
-from .utils import (
-    JumpingKnowledge,
-    LabelPropagation,
-    LaplacianPosEnc,
-    Sequential,
-    WeightBasis,
-)
+from .utils import JumpingKnowledge, LabelPropagation, Sequential, WeightBasis
 from .network_emb import *
-from .graph_transformer import *
+from .gt import *
python/dgl/nn/pytorch/gt/__init__.py (new file, mode 100644)

"""Torch modules for Graph Transformer."""
from .biased_mha import BiasedMHA
from .degree_encoder import DegreeEncoder
from .graphormer import GraphormerLayer
from .lap_pos_encoder import LapPosEncoder
from .path_encoder import PathEncoder
from .spatial_encoder import SpatialEncoder, SpatialEncoder3d
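Because python/dgl/nn/pytorch/__init__.py above now does `from .gt import *`, the new classes stay reachable without the gt prefix. A quick check, assuming the usual dgl.nn re-export of the active backend's pytorch namespace:

from dgl.nn import (
    BiasedMHA,
    DegreeEncoder,
    GraphormerLayer,
    LapPosEncoder,
    PathEncoder,
    SpatialEncoder,
    SpatialEncoder3d,
)

# The fully qualified paths listed in nn-pytorch.rst resolve to the same objects.
from dgl.nn.pytorch.gt import BiasedMHA as BiasedMHA_qualified
assert BiasedMHA is BiasedMHA_qualified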
python/dgl/nn/pytorch/gt/biased_mha.py (new file, mode 100644)

"""Biased Multi-head Attention"""
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class BiasedMHA(nn.Module):
    r"""Dense Multi-Head Attention Module with Graph Attention Bias.

    Compute attention between nodes with attention bias obtained from graph
    structures, as introduced in `Do Transformers Really Perform Bad for
    Graph Representation? <https://arxiv.org/pdf/2106.05234>`__

    .. math::

        \text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)

    :math:`Q` and :math:`K` are feature representations of nodes. :math:`d`
    is the corresponding :attr:`feat_size`. :math:`b` is attention bias, which
    can be additive or multiplicative according to the operator :math:`\circ`.

    Parameters
    ----------
    feat_size : int
        Feature size.
    num_heads : int
        Number of attention heads, by which :attr:`feat_size` is divisible.
    bias : bool, optional
        If True, it uses bias for linear projection. Default: True.
    attn_bias_type : str, optional
        The type of attention bias used for modifying attention. Selected from
        'add' or 'mul'. Default: 'add'.

        * 'add' is for additive attention bias.
        * 'mul' is for multiplicative attention bias.
    attn_drop : float, optional
        Dropout probability on attention weights. Default: 0.1.

    Examples
    --------
    >>> import torch as th
    >>> from dgl.nn import BiasedMHA

    >>> ndata = th.rand(16, 100, 512)
    >>> bias = th.rand(16, 100, 100, 8)
    >>> net = BiasedMHA(feat_size=512, num_heads=8)
    >>> out = net(ndata, bias)
    """

    def __init__(
        self,
        feat_size,
        num_heads,
        bias=True,
        attn_bias_type="add",
        attn_drop=0.1,
    ):
        super().__init__()
        self.feat_size = feat_size
        self.num_heads = num_heads
        self.head_dim = feat_size // num_heads
        assert (
            self.head_dim * num_heads == feat_size
        ), "feat_size must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.attn_bias_type = attn_bias_type

        self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.k_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.v_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.out_proj = nn.Linear(feat_size, feat_size, bias=bias)

        self.dropout = nn.Dropout(p=attn_drop)

        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters of projection matrices, the same settings as in
        the original implementation of the paper.
        """
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(self, ndata, attn_bias=None, attn_mask=None):
        """Forward computation.

        Parameters
        ----------
        ndata : torch.Tensor
            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
            N is the maximum number of nodes.
        attn_bias : torch.Tensor, optional
            The attention bias used for attention modification. Shape:
            (batch_size, N, N, :attr:`num_heads`).
        attn_mask : torch.Tensor, optional
            The attention mask used for avoiding computation on invalid
            positions, where invalid positions are indicated by `True` values.
            Shape: (batch_size, N, N). Note: For rows corresponding to
            nonexistent nodes, make sure at least one entry is set to `False`
            to prevent obtaining NaNs with softmax.

        Returns
        -------
        y : torch.Tensor
            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
        """
        q_h = self.q_proj(ndata).transpose(0, 1)
        k_h = self.k_proj(ndata).transpose(0, 1)
        v_h = self.v_proj(ndata).transpose(0, 1)
        bsz, N, _ = ndata.shape
        q_h = (
            q_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            * self.scaling
        )
        k_h = k_h.reshape(N, bsz * self.num_heads, self.head_dim).permute(
            1, 2, 0
        )
        v_h = v_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(
            0, 1
        )

        attn_weights = (
            th.bmm(q_h, k_h)
            .transpose(0, 2)
            .reshape(N, N, bsz, self.num_heads)
            .transpose(0, 2)
        )

        if attn_bias is not None:
            if self.attn_bias_type == "add":
                attn_weights += attn_bias
            else:
                attn_weights *= attn_bias
        if attn_mask is not None:
            attn_weights[attn_mask.to(th.bool)] = float("-inf")
        attn_weights = F.softmax(
            attn_weights.transpose(0, 2)
            .reshape(N, N, bsz * self.num_heads)
            .transpose(0, 2),
            dim=2,
        )

        attn_weights = self.dropout(attn_weights)

        attn = th.bmm(attn_weights, v_h).transpose(0, 1)

        attn = self.out_proj(
            attn.reshape(N, bsz, self.feat_size).transpose(0, 1)
        )

        return attn
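The new attn_mask semantics (True marks an invalid position, and a fully masked row turns into NaNs after softmax) suggest building the mask from per-graph node counts. A minimal sketch of that construction, with the padding layout assumed rather than prescribed by this file:

import torch as th
from dgl.nn import BiasedMHA

batch_size, max_nodes, feat_size, num_heads = 2, 4, 16, 4
num_nodes = th.tensor([4, 2])  # real node counts per graph in the batch

# True marks columns that attention must not use (padded nodes).
node_is_pad = th.arange(max_nodes)[None, :] >= num_nodes[:, None]  # (B, N)
attn_mask = node_is_pad[:, None, :].expand(batch_size, max_nodes, max_nodes)
# Every row, including rows of padded nodes, keeps its real columns unmasked,
# so no softmax row is all -inf (the NaN caveat in the docstring above).

ndata = th.rand(batch_size, max_nodes, feat_size)
bias = th.rand(batch_size, max_nodes, max_nodes, num_heads)
net = BiasedMHA(feat_size=feat_size, num_heads=num_heads)
out = net(ndata, attn_bias=bias, attn_mask=attn_mask)
assert out.shape == (batch_size, max_nodes, feat_size)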
python/dgl/nn/pytorch/gt/degree_encoder.py (new file, mode 100644)

"""Degree Encoder"""
import torch as th
import torch.nn as nn

from ....base import DGLError


class DegreeEncoder(nn.Module):
    r"""Degree Encoder, as introduced in
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__

    This module is a learnable degree embedding module.

    Parameters
    ----------
    max_degree : int
        Upper bound of degrees to be encoded.
        Each degree will be clamped into the range [0, ``max_degree``].
    embedding_dim : int
        Output dimension of embedding vectors.
    direction : str, optional
        Degrees of which direction to be encoded,
        selected from ``in``, ``out`` and ``both``.
        ``both`` encodes degrees from both directions
        and output the addition of them.
        Default : ``both``.

    Example
    -------
    >>> import dgl
    >>> from dgl.nn import DegreeEncoder

    >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
    >>> degree_encoder = DegreeEncoder(5, 16)
    >>> degree_embedding = degree_encoder(g)
    """

    def __init__(self, max_degree, embedding_dim, direction="both"):
        super(DegreeEncoder, self).__init__()
        self.direction = direction
        if direction == "both":
            self.encoder1 = nn.Embedding(
                max_degree + 1, embedding_dim, padding_idx=0
            )
            self.encoder2 = nn.Embedding(
                max_degree + 1, embedding_dim, padding_idx=0
            )
        else:
            self.encoder = nn.Embedding(
                max_degree + 1, embedding_dim, padding_idx=0
            )
        self.max_degree = max_degree

    def forward(self, g):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded. Graphs with more than one type of edges
            are not allowed.

        Returns
        -------
        Tensor
            Return degree embedding vectors of shape :math:`(N, d)`,
            where :math:`N` is the number of nodes in the input graph and
            :math:`d` is :attr:`embedding_dim`.
        """
        if len(g.etypes) > 1:
            raise DGLError(
                "The input graph should have no more than one type of edges."
            )
        in_degree = th.clamp(g.in_degrees(), min=0, max=self.max_degree)
        out_degree = th.clamp(g.out_degrees(), min=0, max=self.max_degree)

        if self.direction == "in":
            degree_embedding = self.encoder(in_degree)
        elif self.direction == "out":
            degree_embedding = self.encoder(out_degree)
        elif self.direction == "both":
            degree_embedding = self.encoder1(in_degree) + self.encoder2(
                out_degree
            )
        else:
            raise ValueError(
                f'Supported direction options: "in", "out" and "both", '
                f"but got {self.direction}"
            )
        return degree_embedding
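Unlike the old version in graph_transformer.py (removed further down), which silently converted heterogeneous inputs with to_homogeneous, the new DegreeEncoder rejects graphs with more than one edge type. A small sketch of the behavior change, assuming DGLError is re-exported at the dgl package root as usual:

import dgl
import torch as th
from dgl import DGLError
from dgl.nn import DegreeEncoder

g = dgl.graph(([0, 0, 1, 2], [1, 2, 2, 0]))
encoder = DegreeEncoder(max_degree=5, embedding_dim=8, direction="both")
emb = encoder(g)  # in-degree and out-degree embeddings are summed
assert emb.shape == (g.num_nodes(), 8)

hg = dgl.heterograph({
    ("drug", "interacts", "drug"): (th.tensor([0]), th.tensor([1])),
    ("drug", "treats", "disease"): (th.tensor([0]), th.tensor([0])),
})
try:
    encoder(hg)
except DGLError:
    # More than one edge type is now rejected instead of being homogenized.
    pass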
python/dgl/nn/pytorch/gt/graphormer.py (new file, mode 100644)

"""Graphormer Layer"""
import torch.nn as nn

from .biased_mha import BiasedMHA


class GraphormerLayer(nn.Module):
    r"""Graphormer Layer with Dense Multi-Head Attention, as introduced
    in `Do Transformers Really Perform Bad for Graph Representation?
    <https://arxiv.org/pdf/2106.05234>`__

    Parameters
    ----------
    feat_size : int
        Feature size.
    hidden_size : int
        Hidden size of feedforward layers.
    num_heads : int
        Number of attention heads, by which :attr:`feat_size` is divisible.
    attn_bias_type : str, optional
        The type of attention bias used for modifying attention. Selected from
        'add' or 'mul'. Default: 'add'.

        * 'add' is for additive attention bias.
        * 'mul' is for multiplicative attention bias.
    norm_first : bool, optional
        If True, it performs layer normalization before attention and
        feedforward operations. Otherwise, it applies layer normalization
        afterwards. Default: False.
    dropout : float, optional
        Dropout probability. Default: 0.1.
    activation : callable activation layer, optional
        Activation function. Default: nn.ReLU().

    Examples
    --------
    >>> import torch as th
    >>> from dgl.nn import GraphormerLayer

    >>> batch_size = 16
    >>> num_nodes = 100
    >>> feat_size = 512
    >>> num_heads = 8
    >>> nfeat = th.rand(batch_size, num_nodes, feat_size)
    >>> bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    >>> net = GraphormerLayer(
            feat_size=feat_size,
            hidden_size=2048,
            num_heads=num_heads
        )
    >>> out = net(nfeat, bias)
    """

    def __init__(
        self,
        feat_size,
        hidden_size,
        num_heads,
        attn_bias_type="add",
        norm_first=False,
        dropout=0.1,
        activation=nn.ReLU(),
    ):
        super().__init__()

        self.norm_first = norm_first

        self.attn = BiasedMHA(
            feat_size=feat_size,
            num_heads=num_heads,
            attn_bias_type=attn_bias_type,
            attn_drop=dropout,
        )
        self.ffn = nn.Sequential(
            nn.Linear(feat_size, hidden_size),
            activation,
            nn.Dropout(p=dropout),
            nn.Linear(hidden_size, feat_size),
            nn.Dropout(p=dropout),
        )

        self.dropout = nn.Dropout(p=dropout)
        self.attn_layer_norm = nn.LayerNorm(feat_size)
        self.ffn_layer_norm = nn.LayerNorm(feat_size)

    def forward(self, nfeat, attn_bias=None, attn_mask=None):
        """Forward computation.

        Parameters
        ----------
        nfeat : torch.Tensor
            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
            N is the maximum number of nodes.
        attn_bias : torch.Tensor, optional
            The attention bias used for attention modification. Shape:
            (batch_size, N, N, :attr:`num_heads`).
        attn_mask : torch.Tensor, optional
            The attention mask used for avoiding computation on invalid
            positions, where invalid positions are indicated by `True` values.
            Shape: (batch_size, N, N). Note: For rows corresponding to
            nonexistent nodes, make sure at least one entry is set to `False`
            to prevent obtaining NaNs with softmax.

        Returns
        -------
        y : torch.Tensor
            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
        """
        residual = nfeat
        if self.norm_first:
            nfeat = self.attn_layer_norm(nfeat)
        nfeat = self.attn(nfeat, attn_bias, attn_mask)
        nfeat = self.dropout(nfeat)
        nfeat = residual + nfeat
        if not self.norm_first:
            nfeat = self.attn_layer_norm(nfeat)
        residual = nfeat
        if self.norm_first:
            nfeat = self.ffn_layer_norm(nfeat)
        nfeat = self.ffn(nfeat)
        nfeat = residual + nfeat
        if not self.norm_first:
            nfeat = self.ffn_layer_norm(nfeat)
        return nfeat
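A minimal sketch of stacking several layers with the pre-norm option (norm_first=True), which the docstring above describes but its example does not exercise; shapes follow the documented (batch_size, N, feat_size) convention:

import torch as th
import torch.nn as nn
from dgl.nn import GraphormerLayer

feat_size, num_heads, num_layers = 128, 8, 4
layers = nn.ModuleList(
    GraphormerLayer(
        feat_size=feat_size,
        hidden_size=4 * feat_size,
        num_heads=num_heads,
        norm_first=True,  # LayerNorm before attention/FFN, residual added outside
        dropout=0.1,
    )
    for _ in range(num_layers)
)

nfeat = th.rand(2, 10, feat_size)     # (batch_size, N, feat_size)
bias = th.rand(2, 10, 10, num_heads)  # e.g. output of PathEncoder/SpatialEncoder
for layer in layers:
    nfeat = layer(nfeat, attn_bias=bias)
assert nfeat.shape == (2, 10, feat_size)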
python/dgl/nn/pytorch/gt/lap_pos_encoder.py (new file, mode 100644)

"""Laplacian Positional Encoder"""
import torch as th
import torch.nn as nn


class LapPosEncoder(nn.Module):
    r"""Laplacian Positional Encoder (LPE), as introduced in
    `GraphGPS: General Powerful Scalable Graph Transformers
    <https://arxiv.org/abs/2205.12454>`__

    This module is a learned laplacian positional encoding module using
    Transformer or DeepSet.

    Parameters
    ----------
    model_type : str
        Encoder model type for LPE, can only be "Transformer" or "DeepSet".
    num_layer : int
        Number of layers in Transformer/DeepSet Encoder.
    k : int
        Number of smallest non-trivial eigenvectors.
    dim : int
        Output size of final laplacian encoding.
    n_head : int, optional
        Number of heads in Transformer Encoder.
        Default : 1.
    batch_norm : bool, optional
        If True, apply batch normalization on raw laplacian positional
        encoding. Default : False.
    num_post_layer : int, optional
        If num_post_layer > 0, apply an MLP of ``num_post_layer`` layers after
        pooling. Default : 0.

    Example
    -------
    >>> import dgl
    >>> from dgl import LapPE
    >>> from dgl.nn import LapPosEncoder

    >>> transform = LapPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
    >>> g = transform(g)
    >>> eigvals, eigvecs = g.ndata['eigval'], g.ndata['eigvec']
    >>> transformer_encoder = LapPosEncoder(
            model_type="Transformer", num_layer=3, k=5, dim=16, n_head=4
        )
    >>> pos_encoding = transformer_encoder(eigvals, eigvecs)
    >>> deepset_encoder = LapPosEncoder(
            model_type="DeepSet", num_layer=3, k=5, dim=16, num_post_layer=2
        )
    >>> pos_encoding = deepset_encoder(eigvals, eigvecs)
    """

    def __init__(
        self,
        model_type,
        num_layer,
        k,
        dim,
        n_head=1,
        batch_norm=False,
        num_post_layer=0,
    ):
        super(LapPosEncoder, self).__init__()
        self.model_type = model_type
        self.linear = nn.Linear(2, dim)

        if self.model_type == "Transformer":
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=dim, nhead=n_head, batch_first=True
            )
            self.pe_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layer
            )
        elif self.model_type == "DeepSet":
            layers = []
            if num_layer == 1:
                layers.append(nn.ReLU())
            else:
                self.linear = nn.Linear(2, 2 * dim)
                layers.append(nn.ReLU())
                for _ in range(num_layer - 2):
                    layers.append(nn.Linear(2 * dim, 2 * dim))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim, dim))
                layers.append(nn.ReLU())
            self.pe_encoder = nn.Sequential(*layers)
        else:
            raise ValueError(
                f"model_type '{model_type}' is not allowed, must be "
                "'Transformer' or 'DeepSet'."
            )

        if batch_norm:
            self.raw_norm = nn.BatchNorm1d(k)
        else:
            self.raw_norm = None

        if num_post_layer > 0:
            layers = []
            if num_post_layer == 1:
                layers.append(nn.Linear(dim, dim))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(dim, 2 * dim))
                layers.append(nn.ReLU())
                for _ in range(num_post_layer - 2):
                    layers.append(nn.Linear(2 * dim, 2 * dim))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim, dim))
                layers.append(nn.ReLU())
            self.post_mlp = nn.Sequential(*layers)
        else:
            self.post_mlp = None

    def forward(self, eigvals, eigvecs):
        r"""
        Parameters
        ----------
        eigvals : Tensor
            Laplacian Eigenvalues of shape :math:`(N, k)`, k different
            eigenvalues repeat N times, can be obtained by using `LaplacianPE`.
        eigvecs : Tensor
            Laplacian Eigenvectors of shape :math:`(N, k)`, can be obtained by
            using `LaplacianPE`.

        Returns
        -------
        Tensor
            Return the laplacian positional encodings of shape :math:`(N, d)`,
            where :math:`N` is the number of nodes in the input graph,
            :math:`d` is :attr:`dim`.
        """
        pos_encoding = th.cat(
            (eigvecs.unsqueeze(2), eigvals.unsqueeze(2)), dim=2
        ).float()
        empty_mask = th.isnan(pos_encoding)
        pos_encoding[empty_mask] = 0

        if self.raw_norm:
            pos_encoding = self.raw_norm(pos_encoding)
        pos_encoding = self.linear(pos_encoding)

        if self.model_type == "Transformer":
            pos_encoding = self.pe_encoder(
                src=pos_encoding, src_key_padding_mask=empty_mask[:, :, 1]
            )
        else:
            pos_encoding = self.pe_encoder(pos_encoding)

        # Remove masked sequences.
        pos_encoding[empty_mask[:, :, 1]] = 0
        # Sum pooling.
        pos_encoding = th.sum(pos_encoding, 1, keepdim=False)
        # MLP post pooling.
        if self.post_mlp:
            pos_encoding = self.post_mlp(pos_encoding)
        return pos_encoding
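The forward pass masks NaN entries (empty_mask = th.isnan(...)), which is exactly what LapPE(padding=True) produces when k exceeds the number of nodes. A short sketch of that interaction, reusing the transform from the docstring example above:

import dgl
import torch as th
from dgl import LapPE
from dgl.nn import LapPosEncoder

# k=5 on a 4-node graph: eigenvectors are zero-padded, eigenvalues NaN-padded.
transform = LapPE(k=5, feat_name="eigvec", eigval_name="eigval", padding=True)
g = dgl.graph(([0, 1, 2, 3, 1, 2, 3, 0], [1, 2, 3, 0, 0, 1, 2, 3]))
g = transform(g)
eigvals, eigvecs = g.ndata["eigval"], g.ndata["eigvec"]
assert th.isnan(eigvals).any()  # the padded positions

encoder = LapPosEncoder(model_type="DeepSet", num_layer=2, k=5, dim=16)
pos = encoder(eigvals, eigvecs)  # NaN positions are zeroed and excluded from pooling
assert pos.shape == (g.num_nodes(), 16)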
python/dgl/nn/pytorch/gt/path_encoder.py (new file, mode 100644)

"""Path Encoder"""
import torch as th
import torch.nn as nn

from ....batch import unbatch
from ....transforms import shortest_dist


class PathEncoder(nn.Module):
    r"""Path Encoder, as introduced in Edge Encoding of
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__

    This module is a learnable path embedding module and encodes the shortest
    path between each pair of nodes as attention bias.

    Parameters
    ----------
    max_len : int
        Maximum number of edges in each path to be encoded.
        Exceeding part of each path will be truncated, i.e.
        truncating edges with serial number no less than :attr:`max_len`.
    feat_dim : int
        Dimension of edge features in the input graph.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism is applied.
        Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import PathEncoder

    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> edata = th.rand(8, 16)
    >>> path_encoder = PathEncoder(2, 16, num_heads=8)
    >>> out = path_encoder(g, edata)
    """

    def __init__(self, max_len, feat_dim, num_heads=1):
        super().__init__()
        self.max_len = max_len
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)

    def forward(self, g, edge_feat):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded, which must be a homogeneous one.
        edge_feat : torch.Tensor
            The input edge feature of shape :math:`(E, d)`,
            where :math:`E` is the number of edges in the input graph and
            :math:`d` is :attr:`feat_dim`.

        Returns
        -------
        torch.Tensor
            Return attention bias as path encoding, of shape
            :math:`(B, N, N, H)`, where :math:`B` is the batch size of
            the input graph, :math:`N` is the maximum number of nodes, and
            :math:`H` is :attr:`num_heads`.
        """
        device = g.device
        g_list = unbatch(g)
        sum_num_edges = 0
        max_num_nodes = th.max(g.batch_num_nodes())
        path_encoding = th.zeros(
            len(g_list), max_num_nodes, max_num_nodes, self.num_heads
        ).to(device)

        for i, ubg in enumerate(g_list):
            num_nodes = ubg.num_nodes()
            num_edges = ubg.num_edges()
            edata = edge_feat[sum_num_edges : (sum_num_edges + num_edges)]
            sum_num_edges = sum_num_edges + num_edges
            edata = th.cat(
                (edata, th.zeros(1, self.feat_dim).to(edata.device)), dim=0
            )
            dist, path = shortest_dist(ubg, root=None, return_paths=True)
            path_len = max(1, min(self.max_len, path.size(dim=2)))

            # shape: [n, n, l], n = num_nodes, l = path_len
            shortest_path = path[:, :, 0:path_len]
            # shape: [n, n]
            shortest_distance = th.clamp(dist, min=1, max=path_len)
            # shape: [n, n, l, d], d = feat_dim
            path_data = edata[shortest_path]
            # shape: [l, h, d]
            edge_embedding = self.embedding_table.weight[
                0 : path_len * self.num_heads
            ].reshape(path_len, self.num_heads, -1)

            # [n, n, l, d] einsum [l, h, d] -> [n, n, h]
            path_encoding[i, :num_nodes, :num_nodes] = th.div(
                th.einsum(
                    "xyld,lhd->xyh", path_data, edge_embedding
                ).permute(2, 0, 1),
                shortest_distance,
            ).permute(1, 2, 0)
        return path_encoding
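The (B, N, N, H) output is shaped to drop straight into BiasedMHA or GraphormerLayer as attn_bias. A sketch with a batched graph, assuming the node features are already padded to the batch's maximum node count:

import dgl
import torch as th
from dgl.nn import PathEncoder, GraphormerLayer

g1 = dgl.graph(([0, 1, 2], [1, 2, 0]))
g2 = dgl.graph(([0, 1], [1, 0]))
bg = dgl.batch([g1, g2])
edata = th.rand(bg.num_edges(), 16)

path_encoder = PathEncoder(max_len=2, feat_dim=16, num_heads=8)
bias = path_encoder(bg, edata)  # (2, 3, 3, 8): B=2 graphs, N=3 max nodes

layer = GraphormerLayer(feat_size=16, hidden_size=32, num_heads=8)
nfeat = th.rand(2, 3, 16)       # padded node features, (B, N, feat_size)
out = layer(nfeat, attn_bias=bias)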
python/dgl/nn/pytorch/graph_transformer.py → python/dgl/nn/pytorch/gt/spatial_encoder.py (renamed)

-"""Torch modules for graph transformers."""
+"""Spatial Encoder"""
 import math

 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F

-from ...batch import unbatch
-from ...convert import to_homogeneous
-from ...transforms import shortest_dist
-
-__all__ = [
-    "DegreeEncoder",
-    "BiasedMultiheadAttention",
-    "PathEncoder",
-    "GraphormerLayer",
-    "SpatialEncoder",
-    "SpatialEncoder3d",
-]
-
-class DegreeEncoder(nn.Module):
-    r"""Degree Encoder, as introduced in
-    `Do Transformers Really Perform Bad for Graph Representation?
-    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__
-    This module is a learnable degree embedding module.
-
-    Parameters
-    ----------
-    max_degree : int
-        Upper bound of degrees to be encoded.
-        Each degree will be clamped into the range [0, ``max_degree``].
-    embedding_dim : int
-        Output dimension of embedding vectors.
-    direction : str, optional
-        Degrees of which direction to be encoded,
-        selected from ``in``, ``out`` and ``both``.
-        ``both`` encodes degrees from both directions
-        and output the addition of them.
-        Default : ``both``.
-
-    Example
-    -------
-    >>> import dgl
-    >>> from dgl.nn import DegreeEncoder
-
-    >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
-    >>> degree_encoder = DegreeEncoder(5, 16)
-    >>> degree_embedding = degree_encoder(g)
-    """
-
-    def __init__(self, max_degree, embedding_dim, direction="both"):
-        super(DegreeEncoder, self).__init__()
-        self.direction = direction
-        if direction == "both":
-            self.degree_encoder_1 = nn.Embedding(
-                max_degree + 1, embedding_dim, padding_idx=0
-            )
-            self.degree_encoder_2 = nn.Embedding(
-                max_degree + 1, embedding_dim, padding_idx=0
-            )
-        else:
-            self.degree_encoder = nn.Embedding(
-                max_degree + 1, embedding_dim, padding_idx=0
-            )
-        self.max_degree = max_degree
-
-    def forward(self, g):
-        """
-        Parameters
-        ----------
-        g : DGLGraph
-            A DGLGraph to be encoded. If it is a heterogeneous one,
-            it will be transformed into a homogeneous one first.
-
-        Returns
-        -------
-        Tensor
-            Return degree embedding vectors of shape :math:`(N, embedding_dim)`,
-            where :math:`N` is th number of nodes in the input graph.
-        """
-        if len(g.ntypes) > 1 or len(g.etypes) > 1:
-            g = to_homogeneous(g)
-        in_degree = th.clamp(g.in_degrees(), min=0, max=self.max_degree)
-        out_degree = th.clamp(g.out_degrees(), min=0, max=self.max_degree)
-
-        if self.direction == "in":
-            degree_embedding = self.degree_encoder(in_degree)
-        elif self.direction == "out":
-            degree_embedding = self.degree_encoder(out_degree)
-        elif self.direction == "both":
-            degree_embedding = self.degree_encoder_1(
-                in_degree
-            ) + self.degree_encoder_2(out_degree)
-        else:
-            raise ValueError(
-                f'Supported direction options: "in", "out" and "both", '
-                f"but got {self.direction}"
-            )
-        return degree_embedding
-
-
-class PathEncoder(nn.Module):
-    r"""Path Encoder, as introduced in Edge Encoding of
-    `Do Transformers Really Perform Bad for Graph Representation?
-    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__
-    This module is a learnable path embedding module and encodes the shortest
-    path between each pair of nodes as attention bias.
-
-    Parameters
-    ----------
-    max_len : int
-        Maximum number of edges in each path to be encoded.
-        Exceeding part of each path will be truncated, i.e.
-        truncating edges with serial number no less than :attr:`max_len`.
-    feat_dim : int
-        Dimension of edge features in the input graph.
-    num_heads : int, optional
-        Number of attention heads if multi-head attention mechanism is applied.
-        Default : 1.
-
-    Examples
-    --------
-    >>> import torch as th
-    >>> import dgl
-    >>> from dgl.nn import PathEncoder
-
-    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
-    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
-    >>> g = dgl.graph((u, v))
-    >>> edata = th.rand(8, 16)
-    >>> path_encoder = PathEncoder(2, 16, num_heads=8)
-    >>> out = path_encoder(g, edata)
-    """
-
-    def __init__(self, max_len, feat_dim, num_heads=1):
-        super().__init__()
-        self.max_len = max_len
-        self.feat_dim = feat_dim
-        self.num_heads = num_heads
-        self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)
-
-    def forward(self, g, edge_feat):
-        """
-        Parameters
-        ----------
-        g : DGLGraph
-            A DGLGraph to be encoded, which must be a homogeneous one.
-        edge_feat : torch.Tensor
-            The input edge feature of shape :math:`(E, feat_dim)`,
-            where :math:`E` is the number of edges in the input graph.
-
-        Returns
-        -------
-        torch.Tensor
-            Return attention bias as path encoding,
-            of shape :math:`(batch_size, N, N, num_heads)`,
-            where :math:`N` is the maximum number of nodes
-            and batch_size is the batch size of the input graph.
-        """
-        g_list = unbatch(g)
-        sum_num_edges = 0
-        max_num_nodes = th.max(g.batch_num_nodes())
-        path_encoding = []
-
-        for ubg in g_list:
-            num_nodes = ubg.num_nodes()
-            num_edges = ubg.num_edges()
-            edata = edge_feat[sum_num_edges : (sum_num_edges + num_edges)]
-            sum_num_edges = sum_num_edges + num_edges
-            edata = th.cat(
-                (edata, th.zeros(1, self.feat_dim).to(edata.device)), dim=0
-            )
-            dist, path = shortest_dist(ubg, root=None, return_paths=True)
-            path_len = max(1, min(self.max_len, path.size(dim=2)))
-
-            # shape: [n, n, l], n = num_nodes, l = path_len
-            shortest_path = path[:, :, 0:path_len]
-            # shape: [n, n]
-            shortest_distance = th.clamp(dist, min=1, max=path_len)
-            # shape: [n, n, l, d], d = feat_dim
-            path_data = edata[shortest_path]
-            # shape: [l, h, d]
-            edge_embedding = self.embedding_table.weight[
-                0 : path_len * self.num_heads
-            ].reshape(path_len, self.num_heads, -1)
-
-            # [n, n, l, d] einsum [l, h, d] -> [n, n, h]
-            # [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf
-            sub_encoding = th.full(
-                (max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
-            )
-            sub_encoding[0:num_nodes, 0:num_nodes] = th.div(
-                th.einsum("xyld,lhd->xyh", path_data, edge_embedding).permute(
-                    2, 0, 1
-                ),
-                shortest_distance,
-            ).permute(1, 2, 0)
-            path_encoding.append(sub_encoding)
-        return th.stack(path_encoding, dim=0)
-
-
-class BiasedMultiheadAttention(nn.Module):
-    r"""Dense Multi-Head Attention Module with Graph Attention Bias.
-
-    Compute attention between nodes with attention bias obtained from graph
-    structures, as introduced in `Do Transformers Really Perform Bad for
-    Graph Representation? <https://arxiv.org/pdf/2106.05234>`__
-
-    .. math::
-
-        \text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)
-
-    :math:`Q` and :math:`K` are feature representation of nodes. :math:`d`
-    is the corresponding :attr:`feat_size`. :math:`b` is attention bias, which
-    can be additive or multiplicative according to the operator :math:`\circ`.
-
-    Parameters
-    ----------
-    feat_size : int
-        Feature size.
-    num_heads : int
-        Number of attention heads, by which attr:`feat_size` is divisible.
-    bias : bool, optional
-        If True, it uses bias for linear projection. Default: True.
-    attn_bias_type : str, optional
-        The type of attention bias used for modifying attention. Selected from
-        'add' or 'mul'. Default: 'add'.
-
-        * 'add' is for additive attention bias.
-        * 'mul' is for multiplicative attention bias.
-    attn_drop : float, optional
-        Dropout probability on attention weights. Defalt: 0.1.
-
-    Examples
-    --------
-    >>> import torch as th
-    >>> from dgl.nn import BiasedMultiheadAttention
-
-    >>> ndata = th.rand(16, 100, 512)
-    >>> bias = th.rand(16, 100, 100, 8)
-    >>> net = BiasedMultiheadAttention(feat_size=512, num_heads=8)
-    >>> out = net(ndata, bias)
-    """
-
-    def __init__(
-        self,
-        feat_size,
-        num_heads,
-        bias=True,
-        attn_bias_type="add",
-        attn_drop=0.1,
-    ):
-        super().__init__()
-        self.feat_size = feat_size
-        self.num_heads = num_heads
-        self.head_dim = feat_size // num_heads
-        assert (
-            self.head_dim * num_heads == feat_size
-        ), "feat_size must be divisible by num_heads"
-        self.scaling = self.head_dim**-0.5
-        self.attn_bias_type = attn_bias_type
-
-        self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)
-        self.k_proj = nn.Linear(feat_size, feat_size, bias=bias)
-        self.v_proj = nn.Linear(feat_size, feat_size, bias=bias)
-        self.out_proj = nn.Linear(feat_size, feat_size, bias=bias)
-
-        self.dropout = nn.Dropout(p=attn_drop)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        """Reset parameters of projection matrices, the same settings as that in Graphormer."""
-        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5)
-        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5)
-        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5)
-
-        nn.init.xavier_uniform_(self.out_proj.weight)
-        if self.out_proj.bias is not None:
-            nn.init.constant_(self.out_proj.bias, 0.0)
-
-    def forward(self, ndata, attn_bias=None, attn_mask=None):
-        """Forward computation.
-
-        Parameters
-        ----------
-        ndata : torch.Tensor
-            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
-            N is the maximum number of nodes.
-        attn_bias : torch.Tensor, optional
-            The attention bias used for attention modification. Shape:
-            (batch_size, N, N, :attr:`num_heads`).
-        attn_mask : torch.Tensor, optional
-            The attention mask used for avoiding computation on invalid positions, where
-            invalid positions are indicated by non-zero values. Shape: (batch_size, N, N).
-
-        Returns
-        -------
-        y : torch.Tensor
-            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
-        """
-        q_h = self.q_proj(ndata).transpose(0, 1)
-        k_h = self.k_proj(ndata).transpose(0, 1)
-        v_h = self.v_proj(ndata).transpose(0, 1)
-        bsz, N, _ = ndata.shape
-        q_h = (
-            q_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-            / self.scaling
-        )
-        k_h = k_h.reshape(N, bsz * self.num_heads, self.head_dim).permute(
-            1, 2, 0
-        )
-        v_h = v_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(
-            0, 1
-        )
-
-        attn_weights = (
-            th.bmm(q_h, k_h)
-            .transpose(0, 2)
-            .reshape(N, N, bsz, self.num_heads)
-            .transpose(0, 2)
-        )
-
-        if attn_bias is not None:
-            if self.attn_bias_type == "add":
-                attn_weights += attn_bias
-            else:
-                attn_weights *= attn_bias
-        if attn_mask is not None:
-            attn_weights[attn_mask.to(th.bool)] = float("-inf")
-        attn_weights = F.softmax(
-            attn_weights.transpose(0, 2)
-            .reshape(N, N, bsz * self.num_heads)
-            .transpose(0, 2),
-            dim=2,
-        )
-
-        attn_weights = self.dropout(attn_weights)
-
-        attn = th.bmm(attn_weights, v_h).transpose(0, 1)
-
-        attn = self.out_proj(
-            attn.reshape(N, bsz, self.feat_size).transpose(0, 1)
-        )
-
-        return attn
-
-
-class GraphormerLayer(nn.Module):
-    r"""Graphormer Layer with Dense Multi-Head Attention, as introduced
-    in `Do Transformers Really Perform Bad for Graph Representation?
-    <https://arxiv.org/pdf/2106.05234>`__
-
-    Parameters
-    ----------
-    feat_size : int
-        Feature size.
-    hidden_size : int
-        Hidden size of feedforward layers.
-    num_heads : int
-        Number of attention heads, by which :attr:`feat_size` is divisible.
-    attn_bias_type : str, optional
-        The type of attention bias used for modifying attention. Selected from
-        'add' or 'mul'. Default: 'add'.
-
-        * 'add' is for additive attention bias.
-        * 'mul' is for multiplicative attention bias.
-    norm_first : bool, optional
-        If True, it performs layer normalization before attention and
-        feedforward operations. Otherwise, it applies layer normalization
-        afterwards. Default: False.
-    dropout : float, optional
-        Dropout probability. Default: 0.1.
-    activation : callable activation layer, optional
-        Activation function. Default: nn.ReLU().
-
-    Examples
-    --------
-    >>> import torch as th
-    >>> from dgl.nn import GraphormerLayer
-
-    >>> batch_size = 16
-    >>> num_nodes = 100
-    >>> feat_size = 512
-    >>> num_heads = 8
-    >>> nfeat = th.rand(batch_size, num_nodes, feat_size)
-    >>> bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
-    >>> net = GraphormerLayer(
-            feat_size=feat_size,
-            hidden_size=2048,
-            num_heads=num_heads
-        )
-    >>> out = net(nfeat, bias)
-    """
-
-    def __init__(
-        self,
-        feat_size,
-        hidden_size,
-        num_heads,
-        attn_bias_type="add",
-        norm_first=False,
-        dropout=0.1,
-        activation=nn.ReLU(),
-    ):
-        super().__init__()
-
-        self.norm_first = norm_first
-
-        self.attn = BiasedMultiheadAttention(
-            feat_size=feat_size,
-            num_heads=num_heads,
-            attn_bias_type=attn_bias_type,
-            attn_drop=dropout,
-        )
-        self.ffn = nn.Sequential(
-            nn.Linear(feat_size, hidden_size),
-            activation,
-            nn.Dropout(p=dropout),
-            nn.Linear(hidden_size, feat_size),
-            nn.Dropout(p=dropout),
-        )
-
-        self.dropout = nn.Dropout(p=dropout)
-        self.attn_layer_norm = nn.LayerNorm(feat_size)
-        self.ffn_layer_norm = nn.LayerNorm(feat_size)
-
-    def forward(self, nfeat, attn_bias=None, attn_mask=None):
-        """Forward computation.
-
-        Parameters
-        ----------
-        nfeat : torch.Tensor
-            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
-            N is the maximum number of nodes.
-        attn_bias : torch.Tensor, optional
-            The attention bias used for attention modification. Shape:
-            (batch_size, N, N, :attr:`num_heads`).
-        attn_mask : torch.Tensor, optional
-            The attention mask used for avoiding computation on invalid
-            positions. Shape: (batch_size, N, N).
-
-        Returns
-        -------
-        y : torch.Tensor
-            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
-        """
-        residual = nfeat
-        if self.norm_first:
-            nfeat = self.attn_layer_norm(nfeat)
-        nfeat = self.attn(nfeat, attn_bias, attn_mask)
-        nfeat = self.dropout(nfeat)
-        nfeat = residual + nfeat
-        if not self.norm_first:
-            nfeat = self.attn_layer_norm(nfeat)
-        residual = nfeat
-        if self.norm_first:
-            nfeat = self.ffn_layer_norm(nfeat)
-        nfeat = self.ffn(nfeat)
-        nfeat = residual + nfeat
-        if not self.norm_first:
-            nfeat = self.ffn_layer_norm(nfeat)
-        return nfeat
-
-
+from ....batch import unbatch
+from ....transforms import shortest_dist
+
+
 class SpatialEncoder(nn.Module):
     r"""Spatial Encoder, as introduced in
     `Do Transformers Really Perform Bad for Graph Representation?
     <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__
-    This module is a learnable spatial embedding module which encodes
+    This module is a learnable spatial embedding module, which encodes
     the shortest distance between each node pair for attention bias.

     Parameters
     ...

@@ -523,9 +70,11 @@ class SpatialEncoder(nn.Module):
         device = g.device
         g_list = unbatch(g)
         max_num_nodes = th.max(g.batch_num_nodes())
-        spatial_encoding = []
+        spatial_encoding = th.zeros(
+            len(g_list), max_num_nodes, max_num_nodes, self.num_heads
+        ).to(device)

-        for ubg in g_list:
+        for i, ubg in enumerate(g_list):
             num_nodes = ubg.num_nodes()
             dist = (
                 th.clamp(
     ...

@@ -537,19 +86,15 @@ class SpatialEncoder(nn.Module):
             )
             # shape: [n, n, h], n = num_nodes, h = num_heads
             dist_embedding = self.embedding_table(dist)
-            # [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf
-            padded_encoding = th.full(
-                (max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
-            ).to(device)
-            padded_encoding[0:num_nodes, 0:num_nodes] = dist_embedding
-            spatial_encoding.append(padded_encoding)
-        return th.stack(spatial_encoding, dim=0)
+            spatial_encoding[i, :num_nodes, :num_nodes] = dist_embedding
+        return spatial_encoding
 class SpatialEncoder3d(nn.Module):
     r"""3D Spatial Encoder, as introduced in
     `One Transformer Can Understand Both 2D & 3D Molecular Data
     <https://arxiv.org/pdf/2210.01765.pdf>`__
     This module encodes pair-wise relation between atom pair :math:`(i,j)` in
     the 3D geometric space, according to the Gaussian Basis Kernel function:
     ...

@@ -631,6 +176,7 @@ class SpatialEncoder3d(nn.Module):
             be a tensor in shape :math:`(N,)`. The scaling factors of
             each pair of nodes are determined by their node types.
             * Otherwise, :attr:`node_type` should be None.
+
         Returns
         -------
         torch.Tensor
     ...

@@ -643,14 +189,16 @@ class SpatialEncoder3d(nn.Module):
         device = g.device
         g_list = unbatch(g)
         max_num_nodes = th.max(g.batch_num_nodes())
-        spatial_encoding = []
+        spatial_encoding = th.zeros(
+            len(g_list), max_num_nodes, max_num_nodes, self.num_heads
+        ).to(device)
         sum_num_nodes = 0
         if (self.max_node_type == 1) != (node_type is None):
             raise ValueError(
                 "input node_type should be None if and only if "
                 "max_node_type is 1."
             )
-        for ubg in g_list:
+        for i, ubg in enumerate(g_list):
             num_nodes = ubg.num_nodes()
             sub_coord = coord[sum_num_nodes : sum_num_nodes + num_nodes]
             # shape: [n, n], n = num_nodes
     ...

@@ -701,11 +249,6 @@ class SpatialEncoder3d(nn.Module):
                 encoding = F.gelu(encoding)
             # [n, n, k] -> [n, n, a], a = num_heads
             encoding = self.linear_layer_2(encoding)
-            # [n, n, a] -> [N, N, a], N = max_num_nodes, padded with -inf
-            padded_encoding = th.full(
-                (max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
-            ).to(device)
-            padded_encoding[0:num_nodes, 0:num_nodes] = encoding
-            spatial_encoding.append(padded_encoding)
+            spatial_encoding[i, :num_nodes, :num_nodes] = encoding
             sum_num_nodes += num_nodes
-        return th.stack(spatial_encoding, dim=0)
+        return spatial_encoding
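After this change both SpatialEncoder and SpatialEncoder3d pre-allocate a zero tensor and fill the valid num_nodes x num_nodes block per graph, so padded positions carry a zero bias instead of -inf. A usage sketch, with the constructor arguments (max_dist, num_heads) assumed from the released DGL API rather than shown in these hunks:

import dgl
import torch as th
from dgl.nn import SpatialEncoder

g1 = dgl.graph(([0, 1, 2], [1, 2, 0]))
g2 = dgl.graph(([0, 1], [1, 0]))
bg = dgl.batch([g1, g2])

# Assumed signature: SpatialEncoder(max_dist, num_heads=1).
spatial_encoder = SpatialEncoder(max_dist=3, num_heads=8)
bias = spatial_encoder(bg)  # (2, 3, 3, 8)

# Padded entries (e.g. the third row/column of the 2-node graph) are now 0
# rather than -inf, so the bias can be added to attention logits directly.
assert th.isfinite(bias).all()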
python/dgl/nn/pytorch/utils.py

@@ -554,155 +554,3 @@ class LabelPropagation(nn.Module):
             y[mask] = labels[mask]

         return y
-
-
-class LaplacianPosEnc(nn.Module):
-    r"""Laplacian Positional Encoder (LPE), as introduced in
-    `GraphGPS: General Powerful Scalable Graph Transformers
-    <https://arxiv.org/abs/2205.12454>`__
-    This module is a learned laplacian positional encoding module using Transformer or DeepSet.
-
-    Parameters
-    ----------
-    model_type : str
-        Encoder model type for LPE, can only be "Transformer" or "DeepSet".
-    num_layer : int
-        Number of layers in Transformer/DeepSet Encoder.
-    k : int
-        Number of smallest non-trivial eigenvectors.
-    lpe_dim : int
-        Output size of final laplacian encoding.
-    n_head : int, optional
-        Number of heads in Transformer Encoder.
-        Default : 1.
-    batch_norm : bool, optional
-        If True, apply batch normalization on raw LaplacianPE.
-        Default : False.
-    num_post_layer : int, optional
-        If num_post_layer > 0, apply an MLP of ``num_post_layer`` layers after pooling.
-        Default : 0.
-
-    Example
-    -------
-    >>> import dgl
-    >>> from dgl import LaplacianPE
-    >>> from dgl.nn import LaplacianPosEnc
-
-    >>> transform = LaplacianPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
-    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
-    >>> g = transform(g)
-    >>> EigVals, EigVecs = g.ndata['eigval'], g.ndata['eigvec']
-    >>> TransformerLPE = LaplacianPosEnc(model_type="Transformer", num_layer=3, k=5,
-                                         lpe_dim=16, n_head=4)
-    >>> PosEnc = TransformerLPE(EigVals, EigVecs)
-    >>> DeepSetLPE = LaplacianPosEnc(model_type="DeepSet", num_layer=3, k=5,
-                                     lpe_dim=16, num_post_layer=2)
-    >>> PosEnc = DeepSetLPE(EigVals, EigVecs)
-    """
-
-    def __init__(
-        self,
-        model_type,
-        num_layer,
-        k,
-        lpe_dim,
-        n_head=1,
-        batch_norm=False,
-        num_post_layer=0,
-    ):
-        super(LaplacianPosEnc, self).__init__()
-        self.model_type = model_type
-        self.linear = nn.Linear(2, lpe_dim)
-
-        if self.model_type == "Transformer":
-            encoder_layer = nn.TransformerEncoderLayer(
-                d_model=lpe_dim, nhead=n_head, batch_first=True
-            )
-            self.pe_encoder = nn.TransformerEncoder(
-                encoder_layer, num_layers=num_layer
-            )
-        elif self.model_type == "DeepSet":
-            layers = []
-            if num_layer == 1:
-                layers.append(nn.ReLU())
-            else:
-                self.linear = nn.Linear(2, 2 * lpe_dim)
-                layers.append(nn.ReLU())
-                for _ in range(num_layer - 2):
-                    layers.append(nn.Linear(2 * lpe_dim, 2 * lpe_dim))
-                    layers.append(nn.ReLU())
-                layers.append(nn.Linear(2 * lpe_dim, lpe_dim))
-                layers.append(nn.ReLU())
-            self.pe_encoder = nn.Sequential(*layers)
-        else:
-            raise ValueError(
-                f"model_type '{model_type}' is not allowed, must be 'Transformer'"
-                "or 'DeepSet'."
-            )
-
-        if batch_norm:
-            self.raw_norm = nn.BatchNorm1d(k)
-        else:
-            self.raw_norm = None
-
-        if num_post_layer > 0:
-            layers = []
-            if num_post_layer == 1:
-                layers.append(nn.Linear(lpe_dim, lpe_dim))
-                layers.append(nn.ReLU())
-            else:
-                layers.append(nn.Linear(lpe_dim, 2 * lpe_dim))
-                layers.append(nn.ReLU())
-                for _ in range(num_post_layer - 2):
-                    layers.append(nn.Linear(2 * lpe_dim, 2 * lpe_dim))
-                    layers.append(nn.ReLU())
-                layers.append(nn.Linear(2 * lpe_dim, lpe_dim))
-                layers.append(nn.ReLU())
-            self.post_mlp = nn.Sequential(*layers)
-        else:
-            self.post_mlp = None
-
-    def forward(self, EigVals, EigVecs):
-        r"""
-        Parameters
-        ----------
-        EigVals : Tensor
-            Laplacian Eigenvalues of shape :math:`(N, k)`, k different eigenvalues repeat N times,
-            can be obtained by using `LaplacianPE`.
-        EigVecs : Tensor
-            Laplacian Eigenvectors of shape :math:`(N, k)`, can be obtained by using `LaplacianPE`.
-
-        Returns
-        -------
-        Tensor
-            Return the laplacian positional encodings of shape :math:`(N, lpe_dim)`,
-            where :math:`N` is the number of nodes in the input graph.
-        """
-        PosEnc = th.cat(
-            (EigVecs.unsqueeze(2), EigVals.unsqueeze(2)), dim=2
-        ).float()
-        empty_mask = th.isnan(PosEnc)
-        PosEnc[empty_mask] = 0
-
-        if self.raw_norm:
-            PosEnc = self.raw_norm(PosEnc)
-        PosEnc = self.linear(PosEnc)
-
-        if self.model_type == "Transformer":
-            PosEnc = self.pe_encoder(
-                src=PosEnc, src_key_padding_mask=empty_mask[:, :, 1]
-            )
-        else:
-            PosEnc = self.pe_encoder(PosEnc)
-
-        # Remove masked sequences
-        PosEnc[empty_mask[:, :, 1]] = 0
-        # Sum pooling
-        PosEnc = th.sum(PosEnc, 1, keepdim=False)
-        # MLP post pooling
-        if self.post_mlp:
-            PosEnc = self.post_mlp(PosEnc)
-        return PosEnc
python/dgl/transforms/functional.py

@@ -84,6 +84,7 @@ __all__ = [
     "radius_graph",
     "random_walk_pe",
     "laplacian_pe",
+    "lap_pe",
     "to_half",
     "to_float",
     "to_double",

@@ -3593,7 +3594,7 @@ def random_walk_pe(g, k, eweight_name=None):
     return PE


-def laplacian_pe(g, k, padding=False, return_eigval=False):
+def lap_pe(g, k, padding=False, return_eigval=False):
     r"""Laplacian Positional Encoding, as introduced in
     `Benchmarking Graph Neural Networks
     <https://arxiv.org/abs/2003.00982>`__

@@ -3606,13 +3607,12 @@ def laplacian_pe(g, k, padding=False, return_eigval=False):
     g : DGLGraph
         The input graph. Must be homogeneous and bidirected.
     k : int
-        Number of smallest non-trivial eigenvectors to use for positional encoding.
+        Number of smallest non-trivial eigenvectors to use for positional
+        encoding.
     padding : bool, optional
-        If False, raise an exception when k>=n.
-        Otherwise, add zero paddings in the end of eigenvectors and 'nan' paddings
-        in the end of eigenvalues when k>=n.
-        Default: False.
-        n is the number of nodes in the given graph.
+        If False, raise an exception when k>=n. Otherwise, add zero paddings
+        in the end of eigenvectors and 'nan' paddings in the end of eigenvalues
+        when k>=n. Default: False. n is the number of nodes in the given graph.
     return_eigval : bool, optional
         If True, return laplacian eigenvalues together with eigenvectors.
         Otherwise, return laplacian eigenvectors only.

@@ -3621,26 +3621,27 @@ def laplacian_pe(g, k, padding=False, return_eigval=False):
     Returns
     -------
     Tensor or (Tensor, Tensor)
-        Return the laplacian positional encodings of shape :math:`(N, k)`, where :math:`N` is the
-        number of nodes in the input graph, when :attr:`return_eigval` is False. The eigenvalues
-        of shape :math:`N` is additionally returned as the second element when :attr:`return_eigval`
+        Return the laplacian positional encodings of shape :math:`(N, k)`,
+        where :math:`N` is the number of nodes in the input graph, when
+        :attr:`return_eigval` is False. The eigenvalues of shape :math:`N` is
+        additionally returned as the second element when :attr:`return_eigval`
         is True.

     Example
     -------
     >>> import dgl
     >>> g = dgl.graph(([0,1,2,3,1,2,3,0], [1,2,3,0,0,1,2,3]))
-    >>> dgl.laplacian_pe(g, 2)
+    >>> dgl.lap_pe(g, 2)
     tensor([[ 7.0711e-01, -6.4921e-17],
            [ 3.0483e-16, -7.0711e-01],
            [-7.0711e-01, -2.4910e-16],
            [ 9.9288e-17,  7.0711e-01]])
-    >>> dgl.laplacian_pe(g, 5, padding=True)
+    >>> dgl.lap_pe(g, 5, padding=True)
     tensor([[ 7.0711e-01, -6.4921e-17,  5.0000e-01,  0.0000e+00,  0.0000e+00],
            [ 3.0483e-16, -7.0711e-01, -5.0000e-01,  0.0000e+00,  0.0000e+00],
            [-7.0711e-01, -2.4910e-16,  5.0000e-01,  0.0000e+00,  0.0000e+00],
            [ 9.9288e-17,  7.0711e-01, -5.0000e-01,  0.0000e+00,  0.0000e+00]])
-    >>> dgl.laplacian_pe(g, 5, padding=True, return_eigval=True)
+    >>> dgl.lap_pe(g, 5, padding=True, return_eigval=True)
     (tensor([[-7.0711e-01,  6.4921e-17, -5.0000e-01,  0.0000e+00,  0.0000e+00],
             [-3.0483e-16,  7.0711e-01,  5.0000e-01,  0.0000e+00,  0.0000e+00],
             [ 7.0711e-01,  2.4910e-16, -5.0000e-01,  0.0000e+00,  0.0000e+00],

@@ -3651,8 +3652,8 @@ def laplacian_pe(g, k, padding=False, return_eigval=False):
     n = g.num_nodes()
     if not padding and n <= k:
         assert (
-            "the number of eigenvectors k must be smaller than the number of nodes n, "
-            + f"{k} and {n} detected."
+            "the number of eigenvectors k must be smaller than the number of "
+            + f"nodes n, {k} and {n} detected."
         )
     # get laplacian matrix as I - D^-0.5 * A * D^-0.5

@@ -3689,6 +3690,12 @@ def laplacian_pe(g, k, padding=False, return_eigval=False):
     return PE


+def laplacian_pe(g, k, padding=False, return_eigval=False):
+    r"""Alias of `dgl.lap_pe`."""
+    dgl_warning("dgl.laplacian_pe will be deprecated. Use dgl.lap_pe please.")
+    return lap_pe(g, k, padding, return_eigval)
+
+
 def to_half(g):
     r"""Cast this graph to use float16 (half-precision) for any
     floating-point edge and node feature data.
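The functional rename keeps dgl.laplacian_pe as a thin alias that warns and forwards to lap_pe. A short sketch of both entry points, matching the doctest above:

import dgl

g = dgl.graph(([0, 1, 2, 3, 1, 2, 3, 0], [1, 2, 3, 0, 0, 1, 2, 3]))

pe_new = dgl.lap_pe(g, k=2)        # preferred name
pe_old = dgl.laplacian_pe(g, k=2)  # warns: "dgl.laplacian_pe will be deprecated."
assert pe_new.shape == pe_old.shape == (4, 2)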
python/dgl/transforms/module.py

@@ -19,7 +19,7 @@
 from scipy.linalg import expm

 from .. import backend as F, convert, function as fn, utils
-from ..base import DGLError
+from ..base import dgl_warning, DGLError
 from . import functional

 try:

@@ -34,6 +34,7 @@ __all__ = [
     "FeatMask",
     "RandomWalkPE",
     "LaplacianPE",
+    "LapPE",
     "AddSelfLoop",
     "RemoveSelfLoop",
     "AddReverse",

@@ -419,7 +420,7 @@ class RandomWalkPE(BaseTransform):
         return g


-class LaplacianPE(BaseTransform):
+class LapPE(BaseTransform):
     r"""Laplacian Positional Encoding, as introduced in
     `Benchmarking Graph Neural Networks
     <https://arxiv.org/abs/2003.00982>`__

@@ -433,23 +434,21 @@ class LaplacianPE(BaseTransform):
     feat_name : str, optional
         Name to store the computed positional encodings in ndata.
     eigval_name : str, optional
-        If None, store laplacian eigenvectors only.
-        Otherwise, it's the name to store corresponding laplacian eigenvalues in ndata.
-        Default: None.
+        If None, store laplacian eigenvectors only. Otherwise, it's the name to
+        store corresponding laplacian eigenvalues in ndata. Default: None.
     padding : bool, optional
         If False, raise an exception when k>=n.
-        Otherwise, add zero paddings in the end of eigenvectors and 'nan' paddings
-        in the end of eigenvalues when k>=n.
-        Default: False.
+        Otherwise, add zero paddings in the end of eigenvectors and 'nan'
+        paddings in the end of eigenvalues when k>=n. Default: False.
         n is the number of nodes in the given graph.

     Example
     -------
     >>> import dgl
-    >>> from dgl import LaplacianPE
-    >>> transform1 = LaplacianPE(k=3)
-    >>> transform2 = LaplacianPE(k=5, padding=True)
-    >>> transform3 = LaplacianPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
+    >>> from dgl import LapPE
+    >>> transform1 = LapPE(k=3)
+    >>> transform2 = LapPE(k=5, padding=True)
+    >>> transform3 = LapPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
     >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
     >>> g1 = transform1(g)
     >>> print(g1.ndata['PE'])

@@ -488,18 +487,26 @@ class LaplacianPE(BaseTransform):
     def __call__(self, g):
         if self.eigval_name:
-            PE, eigval = functional.laplacian_pe(
+            PE, eigval = functional.lap_pe(
                 g, k=self.k, padding=self.padding, return_eigval=True
             )
             eigval = F.repeat(
                 F.reshape(eigval, [1, -1]), g.num_nodes(), dim=0
             )
             g.ndata[self.eigval_name] = F.copy_to(eigval, g.device)
         else:
-            PE = functional.laplacian_pe(
-                g, k=self.k, padding=self.padding
-            )
+            PE = functional.lap_pe(g, k=self.k, padding=self.padding)
         g.ndata[self.feat_name] = F.copy_to(PE, g.device)

         return g


+class LaplacianPE(LapPE):
+    r"""Alias of `LapPE`."""
+
+    def __init__(self, k, feat_name="PE", eigval_name=None, padding=False):
+        super().__init__(k, feat_name, eigval_name, padding)
+        dgl_warning("LaplacianPE will be deprecated. Use LapPE please.")
+
+
 class AddSelfLoop(BaseTransform):
     r"""Add self-loops for each node in the graph and return a new graph.
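The transform-level rename mirrors the functional one: LapPE is the class, and LaplacianPE subclasses it, only adding a dgl_warning in __init__. A short sketch:

import dgl
from dgl import LapPE, LaplacianPE

g = dgl.graph(([0, 1, 2, 3, 4, 2, 3, 1, 4, 0], [2, 3, 1, 4, 0, 0, 1, 2, 3, 4]))

transform = LapPE(k=3, feat_name="PE")
g = transform(g)
assert g.ndata["PE"].shape == (5, 3)

legacy = LaplacianPE(k=3)  # warns: "LaplacianPE will be deprecated. Use LapPE please."
assert isinstance(legacy, LapPE)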
tests/python/common/transforms/test_transform.py

@@ -3065,7 +3065,7 @@ def test_module_random_walk_pe(idtype):
 @parametrize_idtype
-def test_module_laplacian_pe(idtype):
+def test_module_lap_pe(idtype):
     g = dgl.graph(
         ([2, 1, 0, 3, 1, 1], [3, 1, 1, 2, 1, 0]), idtype=idtype, device=F.ctx()
     )

@@ -3090,7 +3090,7 @@ def test_module_laplacian_pe(idtype):
     )
     # without padding (k<n)
-    transform = dgl.LaplacianPE(2, feat_name="lappe")
+    transform = dgl.LapPE(2, feat_name="lappe")
     new_g = transform(g)
     # tensorflow has no abs() api
     if dgl.backend.backend_name == "tensorflow":

@@ -3100,7 +3100,7 @@ def test_module_laplacian_pe(idtype):
     assert F.allclose(new_g.ndata["lappe"].abs(), tgt_pe[:, :2])
     # with padding (k>=n)
-    transform = dgl.LaplacianPE(5, feat_name="lappe", padding=True)
+    transform = dgl.LapPE(5, feat_name="lappe", padding=True)
     new_g = transform(g)
     # tensorflow has no abs() api
     if dgl.backend.backend_name == "tensorflow":

@@ -3110,7 +3110,7 @@ def test_module_laplacian_pe(idtype):
     assert F.allclose(new_g.ndata["lappe"].abs(), tgt_pe)
     # with eigenvalues
-    transform = dgl.LaplacianPE(
+    transform = dgl.LapPE(
         5, feat_name="lappe", eigval_name="eigval", padding=True
     )
     new_g = transform(g)
tests/python/pytorch/nn/test_nn.py

@@ -2227,25 +2227,9 @@ def test_degree_encoder(max_degree, embedding_dim, direction):
             th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
         )
     )
-    # test heterograph
-    hg = dgl.heterograph(
-        {
-            ("drug", "interacts", "drug"): (
-                th.tensor([0, 1]),
-                th.tensor([1, 2]),
-            ),
-            ("drug", "interacts", "gene"): (
-                th.tensor([0, 1]),
-                th.tensor([2, 3]),
-            ),
-            ("drug", "treats", "disease"): (th.tensor([1]), th.tensor([2])),
-        }
-    )
     model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
     de_g = model(g)
-    de_hg = model(hg)
     assert de_g.shape == (4, embedding_dim)
-    assert de_hg.shape == (10, embedding_dim)


 @parametrize_idtype

@@ -2279,7 +2263,7 @@ def test_MetaPath2Vec(idtype):
 @pytest.mark.parametrize("n_head", [1, 4])
 @pytest.mark.parametrize("batch_norm", [True, False])
 @pytest.mark.parametrize("num_post_layer", [0, 1, 2])
-def test_LaplacianPosEnc(
+def test_LapPosEncoder(
     num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
 ):
     ctx = F.ctx()

@@ -2288,12 +2272,12 @@ def test_LaplacianPosEnc(
     EigVals = th.randn((num_nodes, k)).to(ctx)
     EigVecs = th.randn((num_nodes, k)).to(ctx)

-    model = nn.LaplacianPosEnc(
+    model = nn.LapPosEncoder(
         "Transformer", num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
     ).to(ctx)
     assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)

-    model = nn.LaplacianPosEnc(
+    model = nn.LapPosEncoder(
         "DeepSet", num_layer, k,

@@ -2309,16 +2293,12 @@ def test_LaplacianPosEnc(
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
 @pytest.mark.parametrize("attn_drop", [0.1, 0.5])
-def test_BiasedMultiheadAttention(feat_size, num_heads, bias, attn_bias_type, attn_drop):
+def test_BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop):
     ndata = th.rand(16, 100, feat_size)
     attn_bias = th.rand(16, 100, 100, num_heads)
     attn_mask = th.rand(16, 100, 100) < 0.5
-    net = nn.BiasedMultiheadAttention(
-        feat_size, num_heads, bias, attn_bias_type, attn_drop
-    )
+    net = nn.BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop)
     out = net(ndata, attn_bias, attn_mask)
     assert out.shape == (16, 100, feat_size)