OpenDAS / dgl, commit e47a0279 (unverified)

[NN] Refactor DegreeEncoder, SpatialEncoder and PathEncoder (#5799)

Authored by Zhiteng Li on Jun 09, 2023; committed by GitHub on Jun 09, 2023.
Co-authored-by: rudongyu <ru_dongyu@outlook.com>
Parent: 4fd0a158
Showing 4 changed files with 117 additions and 117 deletions.
- python/dgl/nn/pytorch/gt/degree_encoder.py (+28, -22)
- python/dgl/nn/pytorch/gt/path_encoder.py (+31, -50)
- python/dgl/nn/pytorch/gt/spatial_encoder.py (+24, -32)
- tests/python/pytorch/nn/test_nn.py (+34, -13)
python/dgl/nn/pytorch/gt/degree_encoder.py

@@ -3,8 +3,6 @@
 import torch as th
 import torch.nn as nn
 
-from ....base import DGLError
-
 
 class DegreeEncoder(nn.Module):
     r"""Degree Encoder, as introduced in
@@ -31,10 +29,19 @@ class DegreeEncoder(nn.Module):
     -------
     >>> import dgl
     >>> from dgl.nn import DegreeEncoder
+    >>> import torch as th
+    >>> from torch.nn.utils.rnn import pad_sequence
 
-    >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
+    >>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
+    >>> g2 = dgl.graph(([0,1], [1,0]))
+    >>> in_degree = pad_sequence([g1.in_degrees(), g2.in_degrees()], batch_first=True)
+    >>> out_degree = pad_sequence([g1.out_degrees(), g2.out_degrees()], batch_first=True)
+    >>> print(in_degree.shape)
+    torch.Size([2, 4])
     >>> degree_encoder = DegreeEncoder(5, 16)
-    >>> degree_embedding = degree_encoder(g)
+    >>> degree_embedding = degree_encoder(th.stack((in_degree, out_degree)))
+    >>> print(degree_embedding.shape)
+    torch.Size([2, 4, 16])
     """
 
     def __init__(self, max_degree, embedding_dim, direction="both"):
@@ -53,36 +60,35 @@ class DegreeEncoder(nn.Module):
         )
         self.max_degree = max_degree
 
-    def forward(self, g):
+    def forward(self, degrees):
         """
         Parameters
         ----------
-        g : DGLGraph
-            A DGLGraph to be encoded. Graphs with more than one type of edges
-            are not allowed.
+        degrees : Tensor
+            If :attr:`direction` is ``both``, it should be stacked in degrees and out degrees
+            of the batched graph with zero padding, a tensor of shape :math:`(2, B, N)`.
+            Otherwise, it should be zero-padded in degrees or out degrees of the batched
+            graph, a tensor of shape :math:`(B, N)`, where :math:`B` is the batch size
+            of the batched graph, and :math:`N` is the maximum number of nodes.
 
         Returns
         -------
         Tensor
-            Return degree embedding vectors of shape :math:`(N, d)`,
-            where :math:`N` is the number of nodes in the input graph and
-            :math:`d` is :attr:`embedding_dim`.
+            Return degree embedding vectors of shape :math:`(B, N, d)`,
+            where :math:`d` is :attr:`embedding_dim`.
         """
-        if len(g.etypes) > 1:
-            raise DGLError(
-                "The input graph should have no more than one type of edges."
-            )
-        in_degree = th.clamp(g.in_degrees(), min=0, max=self.max_degree)
-        out_degree = th.clamp(g.out_degrees(), min=0, max=self.max_degree)
+        degrees = th.clamp(degrees, min=0, max=self.max_degree)
         if self.direction == "in":
-            degree_embedding = self.encoder(in_degree)
+            assert len(degrees.shape) == 2
+            degree_embedding = self.encoder(degrees)
         elif self.direction == "out":
-            degree_embedding = self.encoder(out_degree)
+            assert len(degrees.shape) == 2
+            degree_embedding = self.encoder(degrees)
         elif self.direction == "both":
-            degree_embedding = self.encoder1(in_degree) + self.encoder2(
-                out_degree
-            )
+            assert len(degrees.shape) == 3 and degrees.shape[0] == 2
+            degree_embedding = self.encoder1(degrees[0]) + self.encoder2(
+                degrees[1]
+            )
         else:
             raise ValueError(
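Taken together, this file's changes move all graph handling out of DegreeEncoder: forward() now consumes a pre-padded degree tensor instead of a DGLGraph, which is why the DGLError import and the heterograph check disappear. The sketch below walks through the new calling convention; it is a minimal illustration based on the updated docstring, assuming a DGL build that includes this commit.

import dgl
import torch as th
from dgl.nn import DegreeEncoder
from torch.nn.utils.rnn import pad_sequence

# Two toy graphs, as in the updated docstring.
g1 = dgl.graph(([0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1]))
g2 = dgl.graph(([0, 1], [1, 0]))

# The caller, not the encoder, now pads per-graph degree vectors to (B, N).
in_degree = pad_sequence([g1.in_degrees(), g2.in_degrees()], batch_first=True)
out_degree = pad_sequence([g1.out_degrees(), g2.out_degrees()], batch_first=True)

degree_encoder = DegreeEncoder(max_degree=5, embedding_dim=16)
# direction defaults to "both", which expects a stacked (2, B, N) tensor.
degree_embedding = degree_encoder(th.stack((in_degree, out_degree)))
print(degree_embedding.shape)  # torch.Size([2, 4, 16])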
python/dgl/nn/pytorch/gt/path_encoder.py

@@ -2,9 +2,6 @@
 import torch as th
 import torch.nn as nn
 
-from ....batch import unbatch
-from ....transforms import shortest_dist
-
 
 class PathEncoder(nn.Module):
     r"""Path Encoder, as introduced in Edge Encoding of
@@ -31,13 +28,21 @@ class PathEncoder(nn.Module):
     >>> import torch as th
     >>> import dgl
     >>> from dgl.nn import PathEncoder
+    >>> from dgl import shortest_dist
 
-    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
-    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
-    >>> g = dgl.graph((u, v))
+    >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
     >>> edata = th.rand(8, 16)
+    >>> # Since shortest_dist returns -1 for unreachable node pairs,
+    >>> # edata[-1] should be filled with zero padding.
+    >>> edata = th.cat(
            (edata, th.zeros(1, 16)), dim=0
        )
+    >>> dist, path = shortest_dist(g, root=None, return_paths=True)
+    >>> path_data = edata[path[:, :, :2]]
     >>> path_encoder = PathEncoder(2, 16, num_heads=8)
-    >>> out = path_encoder(g, edata)
+    >>> out = path_encoder(dist.unsqueeze(0), path_data.unsqueeze(0))
+    >>> print(out.shape)
+    torch.Size([1, 4, 4, 8])
     """
 
     def __init__(self, max_len, feat_dim, num_heads=1):
@@ -47,16 +52,18 @@ class PathEncoder(nn.Module):
         self.num_heads = num_heads
         self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)
 
-    def forward(self, g, edge_feat):
+    def forward(self, dist, path_data):
         """
         Parameters
         ----------
-        g : DGLGraph
-            A DGLGraph to be encoded, which must be a homogeneous one.
-        edge_feat : torch.Tensor
-            The input edge feature of shape :math:`(E, d)`,
-            where :math:`E` is the number of edges in the input graph and
-            :math:`d` is :attr:`feat_dim`.
+        dist : Tensor
+            Shortest path distance matrix of the batched graph with zero padding,
+            of shape :math:`(B, N, N)`, where :math:`B` is the batch size of
+            the batched graph, and :math:`N` is the maximum number of nodes.
+        path_data : Tensor
+            Edge feature along the shortest path with zero padding, of shape
+            :math:`(B, N, N, L, d)`, where :math:`L` is the maximum length of
+            the shortest paths, and :math:`d` is :attr:`feat_dim`.
 
         Returns
         -------
@@ -66,40 +73,14 @@ class PathEncoder(nn.Module):
             the input graph, :math:`N` is the maximum number of nodes, and
             :math:`H` is :attr:`num_heads`.
         """
-        device = g.device
-        g_list = unbatch(g)
-        sum_num_edges = 0
-        max_num_nodes = th.max(g.batch_num_nodes())
-        path_encoding = th.zeros(
-            len(g_list), max_num_nodes, max_num_nodes, self.num_heads
-        ).to(device)
-        for i, ubg in enumerate(g_list):
-            num_nodes = ubg.num_nodes()
-            num_edges = ubg.num_edges()
-            edata = edge_feat[sum_num_edges : (sum_num_edges + num_edges)]
-            sum_num_edges = sum_num_edges + num_edges
-            edata = th.cat(
-                (edata, th.zeros(1, self.feat_dim).to(edata.device)), dim=0
-            )
-            dist, path = shortest_dist(ubg, root=None, return_paths=True)
-            path_len = max(1, min(self.max_len, path.size(dim=2)))
-            # shape: [n, n, l], n = num_nodes, l = path_len
-            shortest_path = path[:, :, 0:path_len]
-            # shape: [n, n]
-            shortest_distance = th.clamp(dist, min=1, max=path_len)
-            # shape: [n, n, l, d], d = feat_dim
-            path_data = edata[shortest_path]
-            # shape: [l, h, d]
-            edge_embedding = self.embedding_table.weight[
-                0 : path_len * self.num_heads
-            ].reshape(path_len, self.num_heads, -1)
-            # [n, n, l, d] einsum [l, h, d] -> [n, n, h]
-            path_encoding[i, :num_nodes, :num_nodes] = th.div(
-                th.einsum("xyld,lhd->xyh", path_data, edge_embedding).permute(
-                    2, 0, 1
-                ),
-                shortest_distance,
-            ).permute(1, 2, 0)
-        return path_encoding
+        shortest_distance = th.clamp(dist, min=1, max=self.max_len)
+        edge_embedding = self.embedding_table.weight.reshape(
+            self.max_len, self.num_heads, -1
+        )
+        path_encoding = th.div(
+            th.einsum("bxyld,lhd->bxyh", path_data, edge_embedding).permute(
+                3, 0, 1, 2
+            ),
+            shortest_distance,
+        ).permute(1, 2, 3, 0)
+        return path_encoding
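The PathEncoder change follows the same pattern: shortest-path computation and edge-feature gathering move to the caller, and forward() collapses from a per-graph loop into a single batched einsum over pre-gathered path features. Below is a minimal sketch of the new data preparation, mirroring the updated docstring and assuming a DGL build that includes this commit.

import dgl
import torch as th
from dgl import shortest_dist
from dgl.nn import PathEncoder

g = dgl.graph(([0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1]))
edata = th.rand(g.num_edges(), 16)
# shortest_dist marks unreachable pairs and absent path slots with -1, so
# append one zero row that index -1 resolves to.
edata = th.cat((edata, th.zeros(1, 16)), dim=0)

dist, path = shortest_dist(g, root=None, return_paths=True)
# Gather features of at most max_len edges per path: shape (N, N, L, d).
path_data = edata[path[:, :, :2]]

path_encoder = PathEncoder(max_len=2, feat_dim=16, num_heads=8)
# unsqueeze(0) supplies the batch dimension the new forward() expects.
out = path_encoder(dist.unsqueeze(0), path_data.unsqueeze(0))
print(out.shape)  # torch.Size([1, 4, 4, 8])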
python/dgl/nn/pytorch/gt/spatial_encoder.py

@@ -7,7 +7,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ....batch import unbatch
-from ....transforms import shortest_dist
 
 
 class SpatialEncoder(nn.Module):
@@ -33,14 +32,19 @@ class SpatialEncoder(nn.Module):
     >>> import torch as th
     >>> import dgl
     >>> from dgl.nn import SpatialEncoder
+    >>> from dgl import shortest_dist
 
-    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
-    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
-    >>> g = dgl.graph((u, v))
+    >>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
+    >>> g2 = dgl.graph(([0,1], [1,0]))
+    >>> n1, n2 = g1.num_nodes(), g2.num_nodes()
+    >>> # use -1 padding since shortest_dist returns -1 for unreachable node pairs
+    >>> dist = -th.ones((2, 4, 4), dtype=th.long)
+    >>> dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
+    >>> dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)
     >>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
-    >>> out = spatial_encoder(g)
+    >>> out = spatial_encoder(dist)
     >>> print(out.shape)
-    torch.Size([1, 4, 4, 8])
+    torch.Size([2, 4, 4, 8])
     """
 
     def __init__(self, max_dist, num_heads=1):
@@ -52,41 +56,29 @@ class SpatialEncoder(nn.Module):
             max_dist + 2, num_heads, padding_idx=0
         )
 
-    def forward(self, g):
+    def forward(self, dist):
         """
         Parameters
         ----------
-        g : DGLGraph
-            A DGLGraph to be encoded, which must be a homogeneous one.
+        dist : Tensor
+            Shortest path distance of the batched graph with -1 padding, a tensor
+            of shape :math:`(B, N, N)`, where :math:`B` is the batch size of
+            the batched graph, and :math:`N` is the maximum number of nodes.
 
         Returns
         -------
         torch.Tensor
             Return attention bias as spatial encoding of shape
-            :math:`(B, N, N, H)`, where :math:`N` is the maximum number of
-            nodes, :math:`B` is the batch size of the input graph, and
-            :math:`H` is :attr:`num_heads`.
+            :math:`(B, N, N, H)`, where :math:`H` is :attr:`num_heads`.
         """
-        device = g.device
-        g_list = unbatch(g)
-        max_num_nodes = th.max(g.batch_num_nodes())
-        spatial_encoding = th.zeros(
-            len(g_list), max_num_nodes, max_num_nodes, self.num_heads
-        ).to(device)
-
-        for i, ubg in enumerate(g_list):
-            num_nodes = ubg.num_nodes()
-            dist = (
-                th.clamp(
-                    shortest_dist(ubg, root=None, return_paths=False),
-                    min=-1,
-                    max=self.max_dist,
-                )
-                + 1
-            )
-            # shape: [n, n, h], n = num_nodes, h = num_heads
-            dist_embedding = self.embedding_table(dist)
-            spatial_encoding[i, :num_nodes, :num_nodes] = dist_embedding
-        return spatial_encoding
+        spatial_encoding = self.embedding_table(
+            th.clamp(
+                dist,
+                min=-1,
+                max=self.max_dist,
+            )
+            + 1
+        )
+        return spatial_encoding
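SpatialEncoder likewise drops the per-graph shortest_dist loop; forward() is now a single embedding lookup over a clamped distance tensor. A minimal sketch of the new call, again following the updated docstring and assuming a DGL build that includes this commit:

import dgl
import torch as th
from dgl import shortest_dist
from dgl.nn import SpatialEncoder

g1 = dgl.graph(([0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1]))
g2 = dgl.graph(([0, 1], [1, 0]))
n1, n2 = g1.num_nodes(), g2.num_nodes()

# Pad with -1, the value shortest_dist uses for unreachable pairs; the new
# forward() shifts distances by +1, so -1 lands on padding_idx=0.
dist = -th.ones((2, 4, 4), dtype=th.long)
dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)

spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
out = spatial_encoder(dist)
print(out.shape)  # torch.Size([2, 4, 4, 8])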
tests/python/pytorch/nn/test_nn.py

@@ -12,6 +12,8 @@ import pytest
 import scipy as sp
 import torch
 import torch as th
+from dgl import shortest_dist
+from torch.nn.utils.rnn import pad_sequence
 from torch.optim import Adam, SparseAdam
 from torch.utils.data import DataLoader
 from utils import parametrize_idtype
@@ -2389,15 +2391,32 @@ def test_DeepWalk():
 @pytest.mark.parametrize("embedding_dim", [8, 16])
 @pytest.mark.parametrize("direction", ["in", "out", "both"])
 def test_degree_encoder(max_degree, embedding_dim, direction):
-    g = dgl.graph(
+    g1 = dgl.graph(
         (
             th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
             th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
         )
     )
+    g2 = dgl.graph(
+        (
+            th.tensor([0, 1]),
+            th.tensor([1, 0]),
+        )
+    )
+    in_degree = pad_sequence(
+        [g1.in_degrees(), g2.in_degrees()], batch_first=True
+    )
+    out_degree = pad_sequence(
+        [g1.out_degrees(), g2.out_degrees()], batch_first=True
+    )
     model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
-    de_g = model(g)
     if direction == "in":
-        assert de_g.shape == (4, embedding_dim)
+        de_g = model(in_degree)
+    elif direction == "out":
+        de_g = model(out_degree)
+    elif direction == "both":
+        de_g = model(th.stack((in_degree, out_degree)))
+    assert de_g.shape == (2, 4, embedding_dim)
 
 
 @parametrize_idtype
@@ -2498,25 +2517,24 @@ def test_GraphormerLayer(attn_bias_type, norm_first):
     assert out.shape == (batch_size, num_nodes, feat_size)
 
 
-@pytest.mark.parametrize("max_len", [1, 4])
+@pytest.mark.parametrize("max_len", [1, 2])
 @pytest.mark.parametrize("feat_dim", [16])
 @pytest.mark.parametrize("num_heads", [1, 8])
 def test_PathEncoder(max_len, feat_dim, num_heads):
     dev = F.ctx()
-    g1 = dgl.graph(
+    g = dgl.graph(
         (
             th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
             th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
         )
     ).to(dev)
-    g2 = dgl.graph(
-        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))
-    ).to(dev)
-    bg = dgl.batch([g1, g2])
-    edge_feat = th.rand(bg.num_edges(), feat_dim).to(dev)
+    edge_feat = th.rand(g.num_edges(), feat_dim).to(dev)
+    edge_feat = th.cat((edge_feat, th.zeros(1, 16).to(dev)), dim=0)
+    dist, path = shortest_dist(g, root=None, return_paths=True)
+    path_data = edge_feat[path[:, :, :max_len]]
     model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
-    bias = model(bg, edge_feat)
-    assert bias.shape == (2, 6, 6, num_heads)
+    bias = model(dist.unsqueeze(0), path_data.unsqueeze(0))
+    assert bias.shape == (1, 4, 4, num_heads)
 
 
 @pytest.mark.parametrize("max_dist", [1, 4])
@@ -2537,12 +2555,15 @@ def test_SpatialEncoder(max_dist, num_kernels, num_heads):
     ndata = th.rand(bg.num_nodes(), 3).to(dev)
     num_nodes = bg.num_nodes()
     node_type = th.randint(0, 512, (num_nodes,)).to(dev)
+    dist = -th.ones((2, 6, 6), dtype=th.long).to(dev)
+    dist[0, :4, :4] = shortest_dist(g1, root=None, return_paths=False)
+    dist[1, :6, :6] = shortest_dist(g2, root=None, return_paths=False)
     model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
     model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
     model_3 = nn.SpatialEncoder3d(
         num_kernels, num_heads=num_heads, max_node_type=512
     ).to(dev)
-    encoding = model_1(bg)
+    encoding = model_1(dist)
     encoding3d_1 = model_2(bg, ndata)
     encoding3d_2 = model_3(bg, ndata, node_type)
     assert encoding.shape == (2, 6, 6, num_heads)
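Since all three encoders now share the same tensor-in, tensor-out convention, a caller can precompute every input in one pass over the batch. The sketch below is a hypothetical composition, not code from this commit: the padding loop and names such as graphs, feat_dim, and N are illustrative assumptions layered on top of the refactored API.

import dgl
import torch as th
from dgl import shortest_dist
from dgl.nn import DegreeEncoder, PathEncoder, SpatialEncoder
from torch.nn.utils.rnn import pad_sequence

# Illustrative batch: the two toy graphs used throughout this commit.
graphs = [
    dgl.graph(([0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1])),
    dgl.graph(([0, 1], [1, 0])),
]
feat_dim, max_len = 16, 2
B, N = len(graphs), max(g.num_nodes() for g in graphs)

# Degree inputs: per-graph degree vectors padded to (B, N).
in_deg = pad_sequence([g.in_degrees() for g in graphs], batch_first=True)
out_deg = pad_sequence([g.out_degrees() for g in graphs], batch_first=True)

# Distance and path inputs: -1 padding for unreachable or absent node
# pairs, zero padding for features of paths shorter than max_len.
dist = -th.ones(B, N, N, dtype=th.long)
path_data = th.zeros(B, N, N, max_len, feat_dim)
for i, g in enumerate(graphs):
    d, p = shortest_dist(g, root=None, return_paths=True)
    n = g.num_nodes()
    dist[i, :n, :n] = d
    # A zero row at index -1 absorbs the -1 entries in `p`.
    edata = th.cat((th.rand(g.num_edges(), feat_dim), th.zeros(1, feat_dim)))
    L = min(p.shape[2], max_len)
    path_data[i, :n, :n, :L] = edata[p[:, :, :L]]

deg_emb = DegreeEncoder(5, feat_dim)(th.stack((in_deg, out_deg)))
# PathEncoder clamps dist to min=1, so the -1 padding is harmless here.
path_bias = PathEncoder(max_len, feat_dim, num_heads=8)(dist, path_data)
spatial_bias = SpatialEncoder(max_dist=2, num_heads=8)(dist)
# deg_emb: (2, 4, 16); path_bias and spatial_bias: (2, 4, 4, 8)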