"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "606ae3248e2a553aa14fbc69a632dd3db5092b1c"
Unverified Commit cd817a1a authored by ZhenyuLU_Heliodore's avatar ZhenyuLU_Heliodore Committed by GitHub
Browse files

[NN] Add SpatialEncoder and SpatialEncoder3d (#4991)



* Add SpatialEncoder and SpatialEncoder3d

* Optimize the code execution efficiency

* Fixed certain problems according to Dongyu's suggestions.

* Fix a possible division-by-zero error in PathEncoder; change certain designs in SpatialEncoder

* Fix a typo

* polish the docstring

* fix doc

* lint
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-14-146.ap-northeast-1.compute.internal>
Co-authored-by: default avatarrudongyu <ru_dongyu@outlook.com>
parent ce378327
...@@ -120,6 +120,8 @@ Utility Modules ...@@ -120,6 +120,8 @@ Utility Modules
~dgl.nn.pytorch.graph_transformer.BiasedMultiheadAttention ~dgl.nn.pytorch.graph_transformer.BiasedMultiheadAttention
~dgl.nn.pytorch.graph_transformer.GraphormerLayer ~dgl.nn.pytorch.graph_transformer.GraphormerLayer
~dgl.nn.pytorch.graph_transformer.PathEncoder ~dgl.nn.pytorch.graph_transformer.PathEncoder
~dgl.nn.pytorch.graph_transformer.SpatialEncoder
~dgl.nn.pytorch.graph_transformer.SpatialEncoder3d
Network Embedding Modules Network Embedding Modules
---------------------------------------- ----------------------------------------
......
This diff is collapsed.
...@@ -1844,3 +1844,32 @@ def test_PathEncoder(max_len, feat_dim, num_heads): ...@@ -1844,3 +1844,32 @@ def test_PathEncoder(max_len, feat_dim, num_heads):
model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev) model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
bias = model(bg, edge_feat) bias = model(bg, edge_feat)
assert bias.shape == (2, 6, 6, num_heads) assert bias.shape == (2, 6, 6, num_heads)
@pytest.mark.parametrize('max_dist', [1, 4])
@pytest.mark.parametrize('num_kernels', [8, 16])
@pytest.mark.parametrize('num_heads', [1, 8])
def test_SpatialEncoder(max_dist, num_kernels, num_heads):
    """Smoke-test SpatialEncoder and SpatialEncoder3d output shapes.

    Batches a 4-node and a 6-node graph; the bias tensors are expected to
    be padded to the larger graph, i.e. (batch, 6, 6, num_heads).
    """
    dev = F.ctx()
    # First graph: 4 nodes, 8 directed edges.
    src_a = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    dst_a = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    # Second graph: 6 nodes (node 5 appears only as a source), 6 edges.
    src_b = th.tensor([0, 1, 2, 3, 2, 5])
    dst_b = th.tensor([1, 2, 3, 4, 0, 3])
    bg = dgl.batch([
        dgl.graph((src_a, dst_a)).to(dev),
        dgl.graph((src_b, dst_b)).to(dev),
    ])
    # Random 3-D coordinates for every node in the batched graph.
    coords = th.rand(bg.num_nodes(), 3).to(dev)
    n = bg.num_nodes()
    # Node-type ids drawn from [0, 512) to exercise max_node_type=512 below.
    node_type = th.randint(0, 512, (n,)).to(dev)

    spatial_enc = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
    spatial_enc3d = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
    spatial_enc3d_typed = nn.SpatialEncoder3d(
        num_kernels, num_heads=num_heads, max_node_type=512
    ).to(dev)

    expected_shape = (2, 6, 6, num_heads)
    assert spatial_enc(bg).shape == expected_shape
    assert spatial_enc3d(bg, coords).shape == expected_shape
    assert spatial_enc3d_typed(bg, coords, node_type).shape == expected_shape
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment