Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b51fb6f6
"vscode:/vscode.git/clone" did not exist on "07c0fe4b87a07fc1b42bac738f013b78833559ae"
Unverified
Commit
b51fb6f6
authored
Jun 29, 2023
by
rudongyu
Committed by
GitHub
Jun 29, 2023
Browse files
[NN] Refactor SpatialEncoder3d (#5894)
parent
2ef90be0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
108 additions
and
126 deletions
+108
-126
python/dgl/nn/pytorch/gt/spatial_encoder.py
python/dgl/nn/pytorch/gt/spatial_encoder.py
+77
-114
tests/python/pytorch/nn/test_nn.py
tests/python/pytorch/nn/test_nn.py
+31
-12
No files found.
python/dgl/nn/pytorch/gt/spatial_encoder.py
View file @
b51fb6f6
"""Spatial Encoder"""
import
math
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
....batch
import
unbatch
def gaussian(x, mean, std):
    """Compute the Gaussian basis kernel function.

    Evaluates the Gaussian probability density
    ``exp(-0.5 * ((x - mean) / std) ** 2) / (sqrt(2 * pi) * std)``
    element-wise on a tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input values.
    mean : torch.Tensor or float
        Kernel center(s); must broadcast against ``x``.
    std : torch.Tensor or float
        Kernel standard deviation(s); must broadcast against ``x`` and
        be non-zero.

    Returns
    -------
    torch.Tensor
        Kernel values, same shape as the broadcast of the inputs.
    """
    # Use math.pi instead of the hardcoded 3.14159 approximation
    # (math is already imported at the top of this module).
    norm_const = math.sqrt(2 * math.pi)
    return th.exp(-0.5 * (((x - mean) / std) ** 2)) / (norm_const * std)
class
SpatialEncoder
(
nn
.
Module
):
...
...
@@ -87,30 +90,32 @@ class SpatialEncoder3d(nn.Module):
`One Transformer Can Understand Both 2D & 3D Molecular Data
<https://arxiv.org/pdf/2210.01765.pdf>`__
This module encodes pair-wise relation between
atom
pair :math:`(i,j)` in
This module encodes pair-wise relation between
node
pair :math:`(i,j)` in
the 3D geometric space, according to the Gaussian Basis Kernel function:
:math:`\psi _{(i,j)} ^k =
-
\frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert}
:math:`\psi _{(i,j)} ^k = \frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert}
\exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lvert \lvert r_i -
r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right)
^2 \right)},k=1,...,K,`
where :math:`K` is the number of Gaussian Basis kernels.
:math:`r_i` is the Cartesian coordinate of atom :math:`i`.
:math:`\gamma_{(i,j)}, \beta_{(i,j)}` are learnable scaling factors of
the Gaussian Basis kernels.
where :math:`K` is the number of Gaussian Basis kernels. :math:`r_i` is the
Cartesian coordinate of node :math:`i`.
:math:`\gamma_{(i,j)}, \beta_{(i,j)}` are learnable scaling factors and
biases determined by node types. :math:`\mu^k, \sigma^k` are learnable
centers and standard deviations of the Gaussian Basis kernels.
Parameters
----------
num_kernels : int
Number of Gaussian Basis Kernels to be applied.
Each Gaussian Basis
Kernel contains a learnable kernel center
and a learnable scaling factor
.
Number of Gaussian Basis Kernels to be applied.
Each Gaussian Basis
Kernel contains a learnable kernel center
and a learnable standard
deviation
.
num_heads : int, optional
Number of attention heads if multi-head attention mechanism is applied.
Default : 1.
max_node_type : int, optional
Maximum number of node types. Default : 1.
Maximum number of node types. Each node type has a corresponding
learnable scaling factor and a bias. Default : 100.
Examples
--------
...
...
@@ -118,129 +123,87 @@ class SpatialEncoder3d(nn.Module):
>>> import dgl
>>> from dgl.nn import SpatialEncoder3d
>>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
>>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
>>> g = dgl.graph((u, v))
>>> coordinate = th.rand(4, 3)
>>> node_type = th.tensor([1, 0, 2, 1])
>>> coordinate = th.rand(1, 4, 3)
>>> node_type = th.tensor([[1, 0, 2, 1]])
>>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
... num_heads=8,
... max_node_type=3)
>>> out = spatial_encoder(
g,
coordinate, node_type=node_type)
>>> out = spatial_encoder(coordinate, node_type=node_type)
>>> print(out.shape)
torch.Size([1, 4, 4, 8])
"""
def
__init__
(
self
,
num_kernels
,
num_heads
=
1
,
max_node_type
=
1
):
def
__init__
(
self
,
num_kernels
,
num_heads
=
1
,
max_node_type
=
1
00
):
super
().
__init__
()
self
.
num_kernels
=
num_kernels
self
.
num_heads
=
num_heads
self
.
max_node_type
=
max_node_type
self
.
gaussian_
means
=
nn
.
Embedding
(
1
,
num_kernels
)
self
.
gaussian_
stds
=
nn
.
Embedding
(
1
,
num_kernels
)
self
.
means
=
nn
.
Parameter
(
th
.
empty
(
num_kernels
)
)
self
.
stds
=
nn
.
Parameter
(
th
.
empty
(
num_kernels
)
)
self
.
linear_layer_1
=
nn
.
Linear
(
num_kernels
,
num_kernels
)
self
.
linear_layer_2
=
nn
.
Linear
(
num_kernels
,
num_heads
)
if
max_node_type
==
1
:
self
.
mul
=
nn
.
Embedding
(
1
,
1
)
self
.
bias
=
nn
.
Embedding
(
1
,
1
)
else
:
self
.
mul
=
nn
.
Embedding
(
max_node_type
+
1
,
2
)
self
.
bias
=
nn
.
Embedding
(
max_node_type
+
1
,
2
)
nn
.
init
.
uniform_
(
self
.
gaussian_means
.
weight
,
0
,
3
)
nn
.
init
.
uniform_
(
self
.
gaussian_stds
.
weight
,
0
,
3
)
nn
.
init
.
constant_
(
self
.
mul
.
weight
,
0
)
nn
.
init
.
constant_
(
self
.
bias
.
weight
,
1
)
def
forward
(
self
,
g
,
coord
,
node_type
=
None
):
# There are 2 * max_node_type + 3 pairs of gamma and beta parameters:
# 1. Parameters at position 0 are for default gamma/beta when no node
# type is given
# 2. Parameters at position 1 to max_node_type+1 are for src node types.
# (position 1 is for padded unexisting nodes)
# 3. Parameters at position max_node_type+2 to 2*max_node_type+2 are
# for tgt node types. (position max_node_type+2 is for padded)
# unexisting nodes)
self
.
gamma
=
nn
.
Embedding
(
2
*
max_node_type
+
3
,
1
,
padding_idx
=
0
)
self
.
beta
=
nn
.
Embedding
(
2
*
max_node_type
+
3
,
1
,
padding_idx
=
0
)
nn
.
init
.
uniform_
(
self
.
means
,
0
,
3
)
nn
.
init
.
uniform_
(
self
.
stds
,
0
,
3
)
nn
.
init
.
constant_
(
self
.
gamma
.
weight
,
1
)
nn
.
init
.
constant_
(
self
.
beta
.
weight
,
0
)
def
forward
(
self
,
coord
,
node_type
=
None
):
"""
Parameters
----------
g : DGLGraph
A DGLGraph to be encoded, which must be a homogeneous one.
coord : torch.Tensor
3D coordinates of nodes in :attr:`g`,
of shape :math:`(N, 3)`,
where :math:`N`: is the number of nodes in :attr:`g`.
3D coordinates of nodes in shape :math:`(B, N, 3)`, where :math:`B`
is the batch size, :math:`N`: is the maximum number of nodes.
node_type : torch.Tensor, optional
Node types of
:attr:`g`
. Default : None.
Node type
id
s of
nodes
. Default : None.
* If :attr:`max_node_type` is not 1, :attr:`node_type` needs to
be a tensor in shape :math:`(N,)`. The scaling factors of
each pair of nodes are determined by their node types.
* Otherwise, :attr:`node_type` should be None.
* If specified, :attr:`node_type` should be a tensor in shape
:math:`(B, N,)`. The scaling factors in gaussian kernels of each
pair of nodes are determined by their node types.
* Otherwise, :attr:`node_type` will be set to zeros of the same
shape by default.
Returns
-------
torch.Tensor
Return attention bias as 3D spatial encoding of shape
:math:`(B, n, n, H)`, where :math:`B` is the batch size, :math:`n`
is the maximum number of nodes in unbatched graphs from :attr:`g`,
and :math:`H` is :attr:`num_heads`.
:math:`(B, N, N, H)`, where :math:`H` is :attr:`num_heads`.
"""
device
=
g
.
device
g_list
=
unbatch
(
g
)
max_num_nodes
=
th
.
max
(
g
.
batch_num_nodes
())
spatial_encoding
=
th
.
zeros
(
len
(
g_list
),
max_num_nodes
,
max_num_nodes
,
self
.
num_heads
).
to
(
device
)
sum_num_nodes
=
0
if
(
self
.
max_node_type
==
1
)
!=
(
node_type
is
None
):
raise
ValueError
(
"input node_type should be None if and only if "
"max_node_type is 1."
)
for
i
,
ubg
in
enumerate
(
g_list
):
num_nodes
=
ubg
.
num_nodes
()
sub_coord
=
coord
[
sum_num_nodes
:
sum_num_nodes
+
num_nodes
]
# shape: [n, n], n = num_nodes
euc_dist
=
th
.
cdist
(
sub_coord
,
sub_coord
,
p
=
2
)
bsz
,
N
=
coord
.
shape
[:
2
]
euc_dist
=
th
.
cdist
(
coord
,
coord
,
p
=
2.0
)
# shape: [B, n, n]
if
node_type
is
None
:
# shape: [1]
mul
=
self
.
mul
.
weight
[
0
,
0
]
bias
=
self
.
bias
.
weight
[
0
,
0
]
node_type
=
th
.
zeros
([
bsz
,
N
,
N
,
2
],
device
=
coord
.
device
).
long
()
else
:
sub_node_type
=
node_type
[
sum_num_nodes
:
sum_num_nodes
+
num_nodes
]
mul_embedding
=
self
.
mul
(
sub_node_type
)
bias_embedding
=
self
.
bias
(
sub_node_type
)
# shape: [n, n]
mul
=
mul_embedding
[:,
0
].
unsqueeze
(
-
1
).
repeat
(
1
,
num_nodes
)
+
mul_embedding
[:,
1
].
unsqueeze
(
0
).
repeat
(
num_nodes
,
1
)
bias
=
bias_embedding
[:,
0
].
unsqueeze
(
-
1
).
repeat
(
1
,
num_nodes
)
+
bias_embedding
[:,
1
].
unsqueeze
(
0
).
repeat
(
num_nodes
,
1
)
# shape: [n, n, k], k = num_kernels
scaled_dist
=
(
(
mul
*
euc_dist
+
bias
)
.
repeat
(
self
.
num_kernels
,
1
,
1
)
.
permute
((
1
,
2
,
0
))
)
# shape: [k]
gaussian_mean
=
self
.
gaussian_means
.
weight
.
float
().
view
(
-
1
)
gaussian_var
=
(
self
.
gaussian_stds
.
weight
.
float
().
view
(
-
1
).
abs
()
+
1e-2
)
# shape: [n, n, k]
gaussian_kernel
=
(
(
-
0.5
*
(
th
.
div
(
scaled_dist
-
gaussian_mean
,
gaussian_var
).
square
()
)
)
.
exp
()
.
div
(
-
math
.
sqrt
(
2
*
math
.
pi
)
*
gaussian_var
)
)
src_node_type
=
node_type
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
N
)
tgt_node_type
=
node_type
.
unsqueeze
(
1
).
repeat
(
1
,
N
,
1
)
node_type
=
th
.
stack
(
[
src_node_type
+
2
,
tgt_node_type
+
self
.
max_node_type
+
3
],
dim
=-
1
,
)
# shape: [B, n, n, 2]
# scaled euclidean distance
gamma
=
self
.
gamma
(
node_type
).
sum
(
dim
=-
2
)
# shape: [B, n, n, 1]
beta
=
self
.
beta
(
node_type
).
sum
(
dim
=-
2
)
# shape: [B, n, n, 1]
euc_dist
=
gamma
*
euc_dist
.
unsqueeze
(
-
1
)
+
beta
# shape: [B, n, n, 1]
# gaussian basis kernel
euc_dist
=
euc_dist
.
expand
(
-
1
,
-
1
,
-
1
,
self
.
num_kernels
)
gaussian_kernel
=
gaussian
(
euc_dist
,
self
.
means
,
self
.
stds
.
abs
()
+
1e-2
)
# shape: [B, n, n, K]
# linear projection
encoding
=
self
.
linear_layer_1
(
gaussian_kernel
)
encoding
=
F
.
gelu
(
encoding
)
# [n, n, k] -> [n, n, a], a = num_heads
encoding
=
self
.
linear_layer_2
(
encoding
)
spatial_encoding
[
i
,
:
num_nodes
,
:
num_nodes
]
=
encoding
sum_num_nodes
+=
num_nodes
return
spatial_encoding
encoding
=
self
.
linear_layer_2
(
encoding
)
# shape: [B, n, n, H]
return
encoding
tests/python/pytorch/nn/test_nn.py
View file @
b51fb6f6
...
...
@@ -2547,10 +2547,21 @@ def test_PathEncoder(max_len, feat_dim, num_heads):
@
pytest
.
mark
.
parametrize
(
"max_dist"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"num_kernels"
,
[
8
,
16
])
@
pytest
.
mark
.
parametrize
(
"num_kernels"
,
[
4
,
16
])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
[
1
,
8
])
def
test_SpatialEncoder
(
max_dist
,
num_kernels
,
num_heads
):
dev
=
F
.
ctx
()
# single graph encoding 3d
num_nodes
=
4
coord
=
th
.
rand
(
1
,
num_nodes
,
3
).
to
(
dev
)
node_type
=
th
.
tensor
([[
1
,
0
,
2
,
1
]]).
to
(
dev
)
spatial_encoder
=
nn
.
SpatialEncoder3d
(
num_kernels
=
num_kernels
,
num_heads
=
num_heads
,
max_node_type
=
3
).
to
(
dev
)
out
=
spatial_encoder
(
coord
,
node_type
=
node_type
)
assert
out
.
shape
==
(
1
,
num_nodes
,
num_nodes
,
num_heads
)
# encoding on a batch of graphs
g1
=
dgl
.
graph
(
(
th
.
tensor
([
0
,
0
,
0
,
1
,
1
,
2
,
3
,
3
]),
...
...
@@ -2560,21 +2571,29 @@ def test_SpatialEncoder(max_dist, num_kernels, num_heads):
g2
=
dgl
.
graph
(
(
th
.
tensor
([
0
,
1
,
2
,
3
,
2
,
5
]),
th
.
tensor
([
1
,
2
,
3
,
4
,
0
,
3
]))
).
to
(
dev
)
bg
=
dgl
.
batch
([
g1
,
g2
])
ndata
=
th
.
rand
(
bg
.
num_nodes
(),
3
).
to
(
dev
)
num_nodes
=
bg
.
num_nodes
()
node_type
=
th
.
randint
(
0
,
512
,
(
num_nodes
,)).
to
(
dev
)
dist
=
-
th
.
ones
((
2
,
6
,
6
),
dtype
=
th
.
long
).
to
(
dev
)
bsz
,
max_num_nodes
=
2
,
6
# 2d encoding
dist
=
-
th
.
ones
((
bsz
,
max_num_nodes
,
max_num_nodes
),
dtype
=
th
.
long
).
to
(
dev
)
dist
[
0
,
:
4
,
:
4
]
=
shortest_dist
(
g1
,
root
=
None
,
return_paths
=
False
)
dist
[
1
,
:
6
,
:
6
]
=
shortest_dist
(
g2
,
root
=
None
,
return_paths
=
False
)
model_1
=
nn
.
SpatialEncoder
(
max_dist
,
num_heads
=
num_heads
).
to
(
dev
)
encoding
=
model_1
(
dist
)
assert
encoding
.
shape
==
(
bsz
,
max_num_nodes
,
max_num_nodes
,
num_heads
)
# 3d encoding
coord
=
th
.
rand
(
bsz
,
max_num_nodes
,
3
).
to
(
dev
)
node_type
=
th
.
randint
(
0
,
512
,
(
bsz
,
max_num_nodes
,
),
).
to
(
dev
)
model_2
=
nn
.
SpatialEncoder3d
(
num_kernels
,
num_heads
=
num_heads
).
to
(
dev
)
model_3
=
nn
.
SpatialEncoder3d
(
num_kernels
,
num_heads
=
num_heads
,
max_node_type
=
512
).
to
(
dev
)
encoding
=
model_1
(
dist
)
encoding3d_1
=
model_2
(
bg
,
ndata
)
encoding3d_2
=
model_3
(
bg
,
ndata
,
node_type
)
assert
encoding
.
shape
==
(
2
,
6
,
6
,
num_heads
)
assert
encoding3d_1
.
shape
==
(
2
,
6
,
6
,
num_heads
)
assert
encoding3d_2
.
shape
==
(
2
,
6
,
6
,
num_heads
)
encoding3d_1
=
model_2
(
coord
)
encoding3d_2
=
model_3
(
coord
,
node_type
)
assert
encoding3d_1
.
shape
==
(
bsz
,
max_num_nodes
,
max_num_nodes
,
num_heads
)
assert
encoding3d_2
.
shape
==
(
bsz
,
max_num_nodes
,
max_num_nodes
,
num_heads
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment