"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "b94004a6855ceac3bd68259f45186112def2fecf"
Unverified Commit b51fb6f6 authored by rudongyu, committed by GitHub

[NN] Refactor SpatialEncoder3d (#5894)

parent 2ef90be0
"""Spatial Encoder""" """Spatial Encoder"""
import math
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ....batch import unbatch
def gaussian(x, mean, std):
"""compute gaussian basis kernel function"""
const_pi = 3.14159
a = (2 * const_pi) ** 0.5
return th.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
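A quick sanity check of the helper above (illustrative only, not part of the commit): since gaussian computes an elementwise normal pdf, its values should integrate to roughly 1 over a wide grid:

    x = th.linspace(-10.0, 10.0, 10001)
    density = gaussian(x, th.tensor(0.0), th.tensor(1.0))
    print(th.trapz(density, x))  # ~tensor(1.0000); a normal pdf integrates to 1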
class SpatialEncoder(nn.Module):

@@ -87,30 +90,32 @@ class SpatialEncoder3d(nn.Module):
    `One Transformer Can Understand Both 2D & 3D Molecular Data
    <https://arxiv.org/pdf/2210.01765.pdf>`__

    This module encodes the pair-wise relation between node pair
    :math:`(i,j)` in the 3D geometric space, according to the Gaussian
    Basis Kernel function:

    :math:`\psi _{(i,j)} ^k = \frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert}
    \exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lvert \lvert r_i -
    r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right)
    ^2 \right)}, k = 1, ..., K,`

    where :math:`K` is the number of Gaussian Basis kernels. :math:`r_i` is
    the Cartesian coordinate of node :math:`i`. :math:`\gamma_{(i,j)},
    \beta_{(i,j)}` are learnable scaling factors and biases determined by
    node types. :math:`\mu^k, \sigma^k` are learnable centers and standard
    deviations of the Gaussian Basis kernels.

    Parameters
    ----------
    num_kernels : int
        Number of Gaussian Basis Kernels to be applied. Each Gaussian Basis
        Kernel contains a learnable kernel center and a learnable standard
        deviation.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism is
        applied. Default : 1.
    max_node_type : int, optional
        Maximum number of node types. Each node type has a corresponding
        learnable scaling factor and a bias. Default : 100.

    Examples
    --------
@@ -118,129 +123,87 @@ class SpatialEncoder3d(nn.Module):
    >>> import dgl
    >>> from dgl.nn import SpatialEncoder3d
    >>> coordinate = th.rand(1, 4, 3)
    >>> node_type = th.tensor([[1, 0, 2, 1]])
    >>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
    ...                                    num_heads=8,
    ...                                    max_node_type=3)
    >>> out = spatial_encoder(coordinate, node_type=node_type)
    >>> print(out.shape)
    torch.Size([1, 4, 4, 8])
    """
    def __init__(self, num_kernels, num_heads=1, max_node_type=100):
        super().__init__()
        self.num_kernels = num_kernels
        self.num_heads = num_heads
        self.max_node_type = max_node_type
        self.means = nn.Parameter(th.empty(num_kernels))
        self.stds = nn.Parameter(th.empty(num_kernels))
        self.linear_layer_1 = nn.Linear(num_kernels, num_kernels)
        self.linear_layer_2 = nn.Linear(num_kernels, num_heads)
        # There are 2 * max_node_type + 3 pairs of gamma and beta parameters:
        # 1. Parameters at position 0 are the default gamma/beta when no
        #    node type is given.
        # 2. Parameters at positions 1 to max_node_type+1 are for src node
        #    types (position 1 is for padded, non-existent nodes).
        # 3. Parameters at positions max_node_type+2 to 2*max_node_type+2
        #    are for tgt node types (position max_node_type+2 is for padded,
        #    non-existent nodes).
        self.gamma = nn.Embedding(2 * max_node_type + 3, 1, padding_idx=0)
        self.beta = nn.Embedding(2 * max_node_type + 3, 1, padding_idx=0)
        nn.init.uniform_(self.means, 0, 3)
        nn.init.uniform_(self.stds, 0, 3)
        nn.init.constant_(self.gamma.weight, 1)
        nn.init.constant_(self.beta.weight, 0)
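        # Worked example of the layout above (illustrative, not from the
        # commit; assumes padded nodes carry type id -1, as the comments
        # imply): with max_node_type = 3 each table holds 2 * 3 + 3 = 9 rows.
        # A src node of type 2 reads row 2 + 2 = 4, a tgt node of type 2
        # reads row 2 + 3 + 3 = 8, and a padded node (-1) reads rows 1 and 5
        # respectively, matching the offsets applied in forward().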
    def forward(self, coord, node_type=None):
        """
        Parameters
        ----------
        coord : torch.Tensor
            3D coordinates of nodes, in shape :math:`(B, N, 3)`, where
            :math:`B` is the batch size and :math:`N` is the maximum number
            of nodes.
        node_type : torch.Tensor, optional
            Node type ids of nodes. Default : None.

            * If specified, :attr:`node_type` should be a tensor in shape
              :math:`(B, N)`. The scaling factors in the gaussian kernels
              of each pair of nodes are determined by their node types.
            * Otherwise, :attr:`node_type` will be set to zeros of the
              same shape by default.

        Returns
        -------
        torch.Tensor
            Return attention bias as 3D spatial encoding of shape
            :math:`(B, N, N, H)`, where :math:`H` is :attr:`num_heads`.
        """
        bsz, N = coord.shape[:2]
        euc_dist = th.cdist(coord, coord, p=2.0)  # shape: [B, N, N]
        if node_type is None:
            node_type = th.zeros([bsz, N, N, 2], device=coord.device).long()
        else:
            src_node_type = node_type.unsqueeze(-1).repeat(1, 1, N)
            tgt_node_type = node_type.unsqueeze(1).repeat(1, N, 1)
            node_type = th.stack(
                [src_node_type + 2, tgt_node_type + self.max_node_type + 3],
                dim=-1,
            )  # shape: [B, N, N, 2]

        # scaled euclidean distance
        gamma = self.gamma(node_type).sum(dim=-2)  # shape: [B, N, N, 1]
        beta = self.beta(node_type).sum(dim=-2)  # shape: [B, N, N, 1]
        euc_dist = gamma * euc_dist.unsqueeze(-1) + beta  # shape: [B, N, N, 1]
        # gaussian basis kernel
        euc_dist = euc_dist.expand(-1, -1, -1, self.num_kernels)
        gaussian_kernel = gaussian(
            euc_dist, self.means, self.stds.abs() + 1e-2
        )  # shape: [B, N, N, K]
        # linear projection
        encoding = self.linear_layer_1(gaussian_kernel)
        encoding = F.gelu(encoding)
        encoding = self.linear_layer_2(encoding)  # shape: [B, N, N, H]

        return encoding
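The tensor returned by forward is an additive attention bias, one slice per head. A minimal sketch of wiring it into standard scaled dot-product attention (the names q, k and all shapes here are assumptions for illustration, not part of the commit):

    import torch as th
    from dgl.nn import SpatialEncoder3d

    B, N, H, D = 2, 6, 8, 16
    encoder = SpatialEncoder3d(num_kernels=4, num_heads=H, max_node_type=10)
    bias = encoder(th.rand(B, N, 3))             # [B, N, N, H]
    q = k = th.rand(B, H, N, D)                  # hypothetical query/key tensors
    scores = q @ k.transpose(-1, -2) / D ** 0.5  # [B, H, N, N]
    scores = scores + bias.permute(0, 3, 1, 2)   # add the spatial bias per head
    attn = th.softmax(scores, dim=-1)
    print(attn.shape)                            # torch.Size([2, 8, 6, 6])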
@@ -2547,10 +2547,21 @@ def test_PathEncoder(max_len, feat_dim, num_heads):
@pytest.mark.parametrize("max_dist", [1, 4]) @pytest.mark.parametrize("max_dist", [1, 4])
@pytest.mark.parametrize("num_kernels", [8, 16]) @pytest.mark.parametrize("num_kernels", [4, 16])
@pytest.mark.parametrize("num_heads", [1, 8]) @pytest.mark.parametrize("num_heads", [1, 8])
def test_SpatialEncoder(max_dist, num_kernels, num_heads): def test_SpatialEncoder(max_dist, num_kernels, num_heads):
dev = F.ctx() dev = F.ctx()
    # single graph encoding 3d
    num_nodes = 4
    coord = th.rand(1, num_nodes, 3).to(dev)
    node_type = th.tensor([[1, 0, 2, 1]]).to(dev)
    spatial_encoder = nn.SpatialEncoder3d(
        num_kernels=num_kernels, num_heads=num_heads, max_node_type=3
    ).to(dev)
    out = spatial_encoder(coord, node_type=node_type)
    assert out.shape == (1, num_nodes, num_nodes, num_heads)
    # encoding on a batch of graphs
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
@@ -2560,21 +2571,29 @@ def test_SpatialEncoder(max_dist, num_kernels, num_heads):
    g2 = dgl.graph(
        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))
    ).to(dev)
    bsz, max_num_nodes = 2, 6
    # 2d encoding
    dist = -th.ones((bsz, max_num_nodes, max_num_nodes), dtype=th.long).to(dev)
    dist[0, :4, :4] = shortest_dist(g1, root=None, return_paths=False)
    dist[1, :6, :6] = shortest_dist(g2, root=None, return_paths=False)
    model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
    encoding = model_1(dist)
    assert encoding.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
    # 3d encoding
    coord = th.rand(bsz, max_num_nodes, 3).to(dev)
    node_type = th.randint(0, 512, (bsz, max_num_nodes)).to(dev)
    model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
    model_3 = nn.SpatialEncoder3d(
        num_kernels, num_heads=num_heads, max_node_type=512
    ).to(dev)
    encoding3d_1 = model_2(coord)
    encoding3d_2 = model_3(coord, node_type)
    assert encoding3d_1.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
    assert encoding3d_2.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
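One behaviour worth noting, as an assumption based on the comments in __init__ rather than anything asserted by the tests above: padded slots can carry type id -1, which the offsets in forward() route to the reserved padding rows of the gamma/beta tables:

    enc = nn.SpatialEncoder3d(num_kernels=4, num_heads=2, max_node_type=5)
    coord = th.rand(1, 3, 3)
    node_type = th.tensor([[0, 4, -1]])    # last slot treated as padding
    out = enc(coord, node_type=node_type)  # -1 maps to rows 1 and 7
    assert out.shape == (1, 3, 3, 2)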