Unverified commit cd817a1a authored by ZhenyuLU_Heliodore, committed by GitHub

[NN] Add SpatialEncoder and SpatialEncoder3d (#4991)



* Add SpatialEncoder and SpatialEncoder3d

* Optimize the code execution efficiency

* Fix certain problems according to Dongyu's suggestions

* Fix a possible division-by-zero error in PathEncoder; change certain designs in SpatialEncoder

* Fix a typo

* Polish the docstring

* Fix doc

* Lint
Co-authored-by: Ubuntu <ubuntu@ip-172-31-14-146.ap-northeast-1.compute.internal>
Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent ce378327
@@ -120,6 +120,8 @@ Utility Modules
~dgl.nn.pytorch.graph_transformer.BiasedMultiheadAttention
~dgl.nn.pytorch.graph_transformer.GraphormerLayer
~dgl.nn.pytorch.graph_transformer.PathEncoder
~dgl.nn.pytorch.graph_transformer.SpatialEncoder
~dgl.nn.pytorch.graph_transformer.SpatialEncoder3d
Network Embedding Modules
----------------------------------------
"""Torch modules for graph transformers.""" """Torch modules for graph transformers."""
import math
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ...batch import unbatch
from ...convert import to_homogeneous
from ...transforms import shortest_dist
__all__ = [
    "DegreeEncoder",
    "BiasedMultiheadAttention",
    "PathEncoder",
    "GraphormerLayer",
    "SpatialEncoder",
    "SpatialEncoder3d",
]
class DegreeEncoder(nn.Module):
    r"""Degree Encoder, as introduced in
    `Do Transformers Really Perform Bad for Graph Representation?
@@ -83,14 +89,14 @@ class DegreeEncoder(nn.Module):
        elif self.direction == "out":
            degree_embedding = self.degree_encoder(out_degree)
        elif self.direction == "both":
            degree_embedding = self.degree_encoder_1(
                in_degree
            ) + self.degree_encoder_2(out_degree)
        else:
            raise ValueError(
                f'Supported direction options: "in", "out" and "both", '
                f"but got {self.direction}"
            )
        return degree_embedding
@@ -117,12 +123,13 @@ class PathEncoder(nn.Module):
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import PathEncoder
    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> edata = th.rand(8, 16)
    >>> path_encoder = PathEncoder(2, 16, num_heads=8)
    >>> out = path_encoder(g, edata)
    """
@@ -160,49 +167,36 @@ class PathEncoder(nn.Module):
        for ubg in g_list:
            num_nodes = ubg.num_nodes()
            num_edges = ubg.num_edges()
            edata = edge_feat[sum_num_edges : (sum_num_edges + num_edges)]
            sum_num_edges = sum_num_edges + num_edges
            edata = th.cat(
                (edata, th.zeros(1, self.feat_dim).to(edata.device)), dim=0
            )
            dist, path = shortest_dist(ubg, root=None, return_paths=True)
            path_len = max(1, min(self.max_len, path.size(dim=2)))
            # shape: [n, n, l], n = num_nodes, l = path_len
            shortest_path = path[:, :, 0:path_len]
            # shape: [n, n]
            shortest_distance = th.clamp(dist, min=1, max=path_len)
            # shape: [n, n, l, d], d = feat_dim
            path_data = edata[shortest_path]
            # shape: [l, h, d]
            edge_embedding = self.embedding_table.weight[
                0 : path_len * self.num_heads
            ].reshape(path_len, self.num_heads, -1)
            # [n, n, l, d] einsum [l, h, d] -> [n, n, h]
            # [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf
            sub_encoding = th.full(
                (max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
            )
            sub_encoding[0:num_nodes, 0:num_nodes] = th.div(
                th.einsum("xyld,lhd->xyh", path_data, edge_embedding).permute(
                    2, 0, 1
                ),
                shortest_distance,
            ).permute(1, 2, 0)
            path_encoding.append(sub_encoding)
        return th.stack(path_encoding, dim=0)
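The einsum `"xyld,lhd->xyh"` used here is a per-hop, per-head dot product summed along each shortest path; the later division by `shortest_distance` turns that sum into a mean over hops. A minimal standalone sketch of the equivalence, with hypothetical sizes:

import torch as th

n, l, h, d = 3, 2, 4, 16           # nodes, path length, heads, feat dim
path_data = th.rand(n, n, l, d)    # edge features along each shortest path
edge_embedding = th.rand(l, h, d)  # per-hop, per-head embedding weights

out = th.einsum("xyld,lhd->xyh", path_data, edge_embedding)

# equivalent loop formulation
ref = th.zeros(n, n, h)
for hop in range(l):
    for head in range(h):
        ref[:, :, head] += path_data[:, :, hop] @ edge_embedding[hop, head]
assert th.allclose(out, ref, atol=1e-5)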
@@ -249,7 +243,14 @@ class BiasedMultiheadAttention(nn.Module):
    >>> out = net(ndata, bias)
    """

    def __init__(
        self,
        feat_size,
        num_heads,
        bias=True,
        attn_bias_type="add",
        attn_drop=0.1,
    ):
        super().__init__()
        self.feat_size = feat_size
        self.num_heads = num_heads
@@ -270,8 +271,7 @@ class BiasedMultiheadAttention(nn.Module):
        self.reset_parameters()

    def reset_parameters(self):
        """Reset parameters of projection matrices, the same settings as that in Graphormer."""
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5)
@@ -304,9 +304,16 @@ class BiasedMultiheadAttention(nn.Module):
        k_h = self.k_proj(ndata).transpose(0, 1)
        v_h = self.v_proj(ndata).transpose(0, 1)
        bsz, N, _ = ndata.shape
        q_h = (
            q_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            / self.scaling
        )
        k_h = k_h.reshape(N, bsz * self.num_heads, self.head_dim).permute(
            1, 2, 0
        )
        v_h = v_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(
            0, 1
        )

        attn_weights = (
            th.bmm(q_h, k_h)
@@ -320,10 +327,8 @@ class BiasedMultiheadAttention(nn.Module):
                attn_weights += attn_bias
            else:
                attn_weights *= attn_bias
        if attn_mask is not None:
            attn_weights[attn_mask.to(th.bool)] = float("-inf")
        attn_weights = F.softmax(
            attn_weights.transpose(0, 2)
            .reshape(N, N, bsz * self.num_heads)
@@ -335,7 +340,9 @@ class BiasedMultiheadAttention(nn.Module):
        attn = th.bmm(attn_weights, v_h).transpose(0, 1)

        attn = self.out_proj(
            attn.reshape(N, bsz, self.feat_size).transpose(0, 1)
        )

        return attn
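For readers tracking the reshapes above: the projections fold the batch and head dimensions together so that a single `th.bmm` computes all heads' attention scores at once. A small shape-check sketch with hypothetical sizes:

import torch as th

bsz, N, num_heads, head_dim = 2, 5, 4, 8
x = th.rand(N, bsz, num_heads * head_dim)  # a projection after transpose(0, 1)

q_h = x.reshape(N, bsz * num_heads, head_dim).transpose(0, 1)  # [b*h, N, d]
k_h = x.reshape(N, bsz * num_heads, head_dim).permute(1, 2, 0)  # [b*h, d, N]
v_h = x.reshape(N, bsz * num_heads, head_dim).transpose(0, 1)  # [b*h, N, d]

scores = th.bmm(q_h, k_h)  # [b*h, N, N], one score matrix per batch-head pair
assert th.bmm(scores, v_h).shape == (bsz * num_heads, N, head_dim)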
@@ -392,10 +399,10 @@ class GraphormerLayer(nn.Module):
        feat_size,
        hidden_size,
        num_heads,
        attn_bias_type="add",
        norm_first=False,
        dropout=0.1,
        activation=nn.ReLU(),
    ):
        super().__init__()
@@ -405,14 +412,14 @@ class GraphormerLayer(nn.Module):
            feat_size=feat_size,
            num_heads=num_heads,
            attn_bias_type=attn_bias_type,
            attn_drop=dropout,
        )
        self.ffn = nn.Sequential(
            nn.Linear(feat_size, hidden_size),
            activation,
            nn.Dropout(p=dropout),
            nn.Linear(hidden_size, feat_size),
            nn.Dropout(p=dropout),
        )

        self.dropout = nn.Dropout(p=dropout)
@@ -447,7 +454,6 @@ class GraphormerLayer(nn.Module):
        nfeat = residual + nfeat
        if not self.norm_first:
            nfeat = self.attn_layer_norm(nfeat)
        residual = nfeat
        if self.norm_first:
            nfeat = self.ffn_layer_norm(nfeat)
@@ -455,5 +461,251 @@ class GraphormerLayer(nn.Module):
        nfeat = residual + nfeat
        if not self.norm_first:
            nfeat = self.ffn_layer_norm(nfeat)
        return nfeat
class SpatialEncoder(nn.Module):
    r"""Spatial Encoder, as introduced in
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__

    This module is a learnable spatial embedding module, which encodes
    the shortest distance between each node pair for attention bias.

    Parameters
    ----------
    max_dist : int
        Upper bound of the shortest path distance
        between each node pair to be encoded.
        All distances will be clamped into the range `[0, max_dist]`.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism is applied.
        Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import SpatialEncoder
    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
    >>> out = spatial_encoder(g)
    >>> print(out.shape)
    torch.Size([1, 4, 4, 8])
    """

    def __init__(self, max_dist, num_heads=1):
        super().__init__()
        self.max_dist = max_dist
        self.num_heads = num_heads
        # index 0 deactivates node pairs between which the distance is -1,
        # i.e. unreachable pairs
        self.embedding_table = nn.Embedding(
            max_dist + 2, num_heads, padding_idx=0
        )
    def forward(self, g):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded, which must be a homogeneous one.

        Returns
        -------
        torch.Tensor
            Return attention bias as spatial encoding of shape
            :math:`(B, N, N, H)`, where :math:`B` is the batch size of the
            input graph, :math:`N` is the maximum number of nodes, and
            :math:`H` is :attr:`num_heads`.
        """
        device = g.device
        g_list = unbatch(g)
        max_num_nodes = th.max(g.batch_num_nodes())
        spatial_encoding = []

        for ubg in g_list:
            num_nodes = ubg.num_nodes()
            # shift distances by 1 so that unreachable pairs (distance -1)
            # map to the padding index 0
            dist = (
                th.clamp(
                    shortest_dist(ubg, root=None, return_paths=False),
                    min=-1,
                    max=self.max_dist,
                )
                + 1
            )
            # shape: [n, n, h], n = num_nodes, h = num_heads
            dist_embedding = self.embedding_table(dist)
            # [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf
            padded_encoding = th.full(
                (max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
            ).to(device)
            padded_encoding[0:num_nodes, 0:num_nodes] = dist_embedding
            spatial_encoding.append(padded_encoding)
        return th.stack(spatial_encoding, dim=0)
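A quick sanity check on the batching behavior (a sketch, not part of the diff): positions beyond each unbatched graph's own node count are filled with `-inf`, so a downstream softmax assigns them zero attention.

import torch as th
import dgl
from dgl.nn import SpatialEncoder

g1 = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))              # 3 nodes
g2 = dgl.graph((th.tensor([0, 1, 2, 3]), th.tensor([1, 2, 3, 4])))  # 5 nodes
bg = dgl.batch([g1, g2])

encoder = SpatialEncoder(max_dist=3, num_heads=2)
bias = encoder(bg)
assert bias.shape == (2, 5, 5, 2)   # [batch, max_nodes, max_nodes, heads]
assert th.isinf(bias[0, 3:]).all()  # g1's padded rows are all -inf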
class SpatialEncoder3d(nn.Module):
    r"""3D Spatial Encoder, as introduced in
    `One Transformer Can Understand Both 2D & 3D Molecular Data
    <https://arxiv.org/pdf/2210.01765.pdf>`__

    This module encodes the pair-wise relation between atom pair :math:`(i,j)`
    in the 3D geometric space, according to the Gaussian Basis Kernel function:

    :math:`\psi _{(i,j)} ^k = -\frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert}
    \exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lvert \lvert r_i -
    r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right)
    ^2 \right)}, k = 1, ..., K,`

    where :math:`K` is the number of Gaussian Basis kernels,
    :math:`r_i` is the Cartesian coordinate of atom :math:`i`, and
    :math:`\gamma_{(i,j)}, \beta_{(i,j)}` are learnable scaling factors of
    the Gaussian Basis kernels.

    Parameters
    ----------
    num_kernels : int
        Number of Gaussian Basis Kernels to be applied.
        Each Gaussian Basis Kernel contains a learnable kernel center
        and a learnable scaling factor.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism is applied.
        Default : 1.
    max_node_type : int, optional
        Maximum number of node types. Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import SpatialEncoder3d
    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> coordinate = th.rand(4, 3)
    >>> node_type = th.tensor([1, 0, 2, 1])
    >>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
    ...                                    num_heads=8,
    ...                                    max_node_type=3)
    >>> out = spatial_encoder(g, coordinate, node_type=node_type)
    >>> print(out.shape)
    torch.Size([1, 4, 4, 8])
    """
    def __init__(self, num_kernels, num_heads=1, max_node_type=1):
        super().__init__()
        self.num_kernels = num_kernels
        self.num_heads = num_heads
        self.max_node_type = max_node_type
        self.gaussian_means = nn.Embedding(1, num_kernels)
        self.gaussian_stds = nn.Embedding(1, num_kernels)
        self.linear_layer_1 = nn.Linear(num_kernels, num_kernels)
        self.linear_layer_2 = nn.Linear(num_kernels, num_heads)
        # mul/bias hold the learnable scaling factors gamma and beta; with
        # multiple node types, each type gets separate source and destination
        # entries that are summed per node pair in forward()
        if max_node_type == 1:
            self.mul = nn.Embedding(1, 1)
            self.bias = nn.Embedding(1, 1)
        else:
            self.mul = nn.Embedding(max_node_type + 1, 2)
            self.bias = nn.Embedding(max_node_type + 1, 2)
        nn.init.uniform_(self.gaussian_means.weight, 0, 3)
        nn.init.uniform_(self.gaussian_stds.weight, 0, 3)
        nn.init.constant_(self.mul.weight, 0)
        nn.init.constant_(self.bias.weight, 1)
    def forward(self, g, coord, node_type=None):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded, which must be a homogeneous one.
        coord : torch.Tensor
            3D coordinates of nodes in :attr:`g`, of shape :math:`(N, 3)`,
            where :math:`N` is the number of nodes in :attr:`g`.
        node_type : torch.Tensor, optional
            Node types of :attr:`g`. Default : None.

            * If :attr:`max_node_type` is not 1, :attr:`node_type` needs to
              be a tensor in shape :math:`(N,)`. The scaling factors of
              each pair of nodes are determined by their node types.
            * Otherwise, :attr:`node_type` should be None.

        Returns
        -------
        torch.Tensor
            Return attention bias as 3D spatial encoding of shape
            :math:`(B, n, n, H)`, where :math:`B` is the batch size, :math:`n`
            is the maximum number of nodes in unbatched graphs from :attr:`g`,
            and :math:`H` is :attr:`num_heads`.
        """
        device = g.device
        g_list = unbatch(g)
        max_num_nodes = th.max(g.batch_num_nodes())
        spatial_encoding = []
        sum_num_nodes = 0
        if (self.max_node_type == 1) != (node_type is None):
            raise ValueError(
                "input node_type should be None if and only if "
                "max_node_type is 1."
            )
        for ubg in g_list:
            num_nodes = ubg.num_nodes()
            sub_coord = coord[sum_num_nodes : sum_num_nodes + num_nodes]
            # shape: [n, n], n = num_nodes
            euc_dist = th.cdist(sub_coord, sub_coord, p=2)
            if node_type is None:
                # shape: [1]
                mul = self.mul.weight[0, 0]
                bias = self.bias.weight[0, 0]
            else:
                sub_node_type = node_type[
                    sum_num_nodes : sum_num_nodes + num_nodes
                ]
                mul_embedding = self.mul(sub_node_type)
                bias_embedding = self.bias(sub_node_type)
                # shape: [n, n]
                mul = mul_embedding[:, 0].unsqueeze(-1).repeat(
                    1, num_nodes
                ) + mul_embedding[:, 1].unsqueeze(0).repeat(num_nodes, 1)
                bias = bias_embedding[:, 0].unsqueeze(-1).repeat(
                    1, num_nodes
                ) + bias_embedding[:, 1].unsqueeze(0).repeat(num_nodes, 1)
            # shape: [n, n, k], k = num_kernels
            scaled_dist = (
                (mul * euc_dist + bias)
                .repeat(self.num_kernels, 1, 1)
                .permute((1, 2, 0))
            )
            # shape: [k]
            gaussian_mean = self.gaussian_means.weight.float().view(-1)
            gaussian_var = (
                self.gaussian_stds.weight.float().view(-1).abs() + 1e-2
            )
            # shape: [n, n, k]
            gaussian_kernel = (
                (
                    -0.5
                    * (
                        th.div(
                            scaled_dist - gaussian_mean, gaussian_var
                        ).square()
                    )
                )
                .exp()
                .div(-math.sqrt(2 * math.pi) * gaussian_var)
            )

            encoding = self.linear_layer_1(gaussian_kernel)
            encoding = F.gelu(encoding)
            # [n, n, k] -> [n, n, a], a = num_heads
            encoding = self.linear_layer_2(encoding)
            # [n, n, a] -> [N, N, a], N = max_num_nodes, padded with -inf
            padded_encoding = th.full(
                (max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
            ).to(device)
            padded_encoding[0:num_nodes, 0:num_nodes] = encoding
            spatial_encoding.append(padded_encoding)
            sum_num_nodes += num_nodes
        return th.stack(spatial_encoding, dim=0)
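To make the kernel formula concrete, here is a standalone sketch of :math:`\psi` for a single node pair, with hypothetical values for the learnable parameters; it mirrors the `gaussian_kernel` computation above, including the negative sign from the docstring formula.

import math
import torch as th

K = 4                                          # number of Gaussian Basis kernels
dist = th.tensor(1.5)                          # ||r_i - r_j||
gamma, beta = th.tensor(1.0), th.tensor(0.0)   # scaling factors for pair (i, j)
mu = th.rand(K)                                # kernel centers mu^k
sigma = th.rand(K).abs() + 1e-2                # kernel widths |sigma^k|

x = gamma * dist + beta
psi = -th.exp(-0.5 * ((x - mu) / sigma) ** 2) / (math.sqrt(2 * math.pi) * sigma)
# psi[k] is the value of the k-th kernel, k = 1, ..., K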
@@ -1844,3 +1844,32 @@ def test_PathEncoder(max_len, feat_dim, num_heads):
    model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
    bias = model(bg, edge_feat)
    assert bias.shape == (2, 6, 6, num_heads)
@pytest.mark.parametrize('max_dist', [1, 4])
@pytest.mark.parametrize('num_kernels', [8, 16])
@pytest.mark.parametrize('num_heads', [1, 8])
def test_SpatialEncoder(max_dist, num_kernels, num_heads):
    dev = F.ctx()
    g1 = dgl.graph((
        th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
        th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    )).to(dev)
    g2 = dgl.graph((
        th.tensor([0, 1, 2, 3, 2, 5]),
        th.tensor([1, 2, 3, 4, 0, 3])
    )).to(dev)
    bg = dgl.batch([g1, g2])
    ndata = th.rand(bg.num_nodes(), 3).to(dev)
    num_nodes = bg.num_nodes()
    node_type = th.randint(0, 512, (num_nodes,)).to(dev)
    model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
    model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
    model_3 = nn.SpatialEncoder3d(
        num_kernels, num_heads=num_heads, max_node_type=512
    ).to(dev)
    encoding = model_1(bg)
    encoding3d_1 = model_2(bg, ndata)
    encoding3d_2 = model_3(bg, ndata, node_type)
    assert encoding.shape == (2, 6, 6, num_heads)
    assert encoding3d_1.shape == (2, 6, 6, num_heads)
    assert encoding3d_2.shape == (2, 6, 6, num_heads)