Unverified Commit e47a0279 authored by Zhiteng Li, committed by GitHub

[NN] Refactor DegreeEncoder, SpatialEncoder and PathEncoder (#5799)


Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent 4fd0a158
@@ -3,8 +3,6 @@
 import torch as th
 import torch.nn as nn
-from ....base import DGLError

 class DegreeEncoder(nn.Module):
     r"""Degree Encoder, as introduced in
@@ -31,10 +29,19 @@ class DegreeEncoder(nn.Module):
     -------
     >>> import dgl
     >>> from dgl.nn import DegreeEncoder
-    >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
+    >>> import torch as th
+    >>> from torch.nn.utils.rnn import pad_sequence
+    >>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
+    >>> g2 = dgl.graph(([0,1], [1,0]))
+    >>> in_degree = pad_sequence([g1.in_degrees(), g2.in_degrees()], batch_first=True)
+    >>> out_degree = pad_sequence([g1.out_degrees(), g2.out_degrees()], batch_first=True)
+    >>> print(in_degree.shape)
+    torch.Size([2, 4])
     >>> degree_encoder = DegreeEncoder(5, 16)
-    >>> degree_embedding = degree_encoder(g)
+    >>> degree_embedding = degree_encoder(th.stack((in_degree, out_degree)))
+    >>> print(degree_embedding.shape)
+    torch.Size([2, 4, 16])
     """

     def __init__(self, max_degree, embedding_dim, direction="both"):
@@ -53,36 +60,35 @@ class DegreeEncoder(nn.Module):
         )
         self.max_degree = max_degree

-    def forward(self, g):
+    def forward(self, degrees):
         """
         Parameters
         ----------
-        g : DGLGraph
-            A DGLGraph to be encoded. Graphs with more than one type of edges
-            are not allowed.
+        degrees : Tensor
+            If :attr:`direction` is ``both``, it should be stacked in degrees and out degrees
+            of the batched graph with zero padding, a tensor of shape :math:`(2, B, N)`.
+            Otherwise, it should be zero-padded in degrees or out degrees of the batched
+            graph, a tensor of shape :math:`(B, N)`, where :math:`B` is the batch size
+            of the batched graph, and :math:`N` is the maximum number of nodes.

         Returns
         -------
         Tensor
-            Return degree embedding vectors of shape :math:`(N, d)`,
-            where :math:`N` is the number of nodes in the input graph and
-            :math:`d` is :attr:`embedding_dim`.
+            Return degree embedding vectors of shape :math:`(B, N, d)`,
+            where :math:`d` is :attr:`embedding_dim`.
         """
-        if len(g.etypes) > 1:
-            raise DGLError(
-                "The input graph should have no more than one type of edges."
-            )
-        in_degree = th.clamp(g.in_degrees(), min=0, max=self.max_degree)
-        out_degree = th.clamp(g.out_degrees(), min=0, max=self.max_degree)
+        degrees = th.clamp(degrees, min=0, max=self.max_degree)
         if self.direction == "in":
-            degree_embedding = self.encoder(in_degree)
+            assert len(degrees.shape) == 2
+            degree_embedding = self.encoder(degrees)
         elif self.direction == "out":
-            degree_embedding = self.encoder(out_degree)
+            assert len(degrees.shape) == 2
+            degree_embedding = self.encoder(degrees)
         elif self.direction == "both":
-            degree_embedding = self.encoder1(in_degree) + self.encoder2(
-                out_degree
+            assert len(degrees.shape) == 3 and degrees.shape[0] == 2
+            degree_embedding = self.encoder1(degrees[0]) + self.encoder2(
+                degrees[1]
             )
         else:
             raise ValueError(
...
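After this change, DegreeEncoder no longer takes a DGLGraph; callers pre-compute and pad the degree tensors themselves. A minimal sketch of the new call pattern, adapted from the docstring example above (it assumes a working dgl/torch install; shapes follow the docstring):

import dgl
import torch as th
from torch.nn.utils.rnn import pad_sequence
from dgl.nn import DegreeEncoder

# Two graphs of different sizes; pad_sequence zero-pads the shorter
# degree vectors so both rows share the same node dimension N.
g1 = dgl.graph(([0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1]))
g2 = dgl.graph(([0, 1], [1, 0]))
in_degree = pad_sequence([g1.in_degrees(), g2.in_degrees()], batch_first=True)    # (B, N)
out_degree = pad_sequence([g1.out_degrees(), g2.out_degrees()], batch_first=True)  # (B, N)

encoder = DegreeEncoder(max_degree=5, embedding_dim=16)   # direction="both" by default
emb = encoder(th.stack((in_degree, out_degree)))          # "both" expects shape (2, B, N)
assert emb.shape == (2, 4, 16)                            # (B, N, embedding_dim)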
@@ -2,9 +2,6 @@
 import torch as th
 import torch.nn as nn
-from ....batch import unbatch
-from ....transforms import shortest_dist

 class PathEncoder(nn.Module):
     r"""Path Encoder, as introduced in Edge Encoding of
@@ -31,13 +28,21 @@ class PathEncoder(nn.Module):
     >>> import torch as th
     >>> import dgl
     >>> from dgl.nn import PathEncoder
+    >>> from dgl import shortest_dist
-    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
-    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
-    >>> g = dgl.graph((u, v))
+    >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
     >>> edata = th.rand(8, 16)
+    >>> # Since shortest_dist returns -1 for unreachable node pairs,
+    >>> # edata[-1] should be filled with zero padding.
+    >>> edata = th.cat(
+            (edata, th.zeros(1, 16)), dim=0
+        )
+    >>> dist, path = shortest_dist(g, root=None, return_paths=True)
+    >>> path_data = edata[path[:, :, :2]]
     >>> path_encoder = PathEncoder(2, 16, num_heads=8)
-    >>> out = path_encoder(g, edata)
+    >>> out = path_encoder(dist.unsqueeze(0), path_data.unsqueeze(0))
+    >>> print(out.shape)
+    torch.Size([1, 4, 4, 8])
     """

     def __init__(self, max_len, feat_dim, num_heads=1):
@@ -47,16 +52,18 @@ class PathEncoder(nn.Module):
         self.num_heads = num_heads
         self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)

-    def forward(self, g, edge_feat):
+    def forward(self, dist, path_data):
         """
         Parameters
         ----------
-        g : DGLGraph
-            A DGLGraph to be encoded, which must be a homogeneous one.
-        edge_feat : torch.Tensor
-            The input edge feature of shape :math:`(E, d)`,
-            where :math:`E` is the number of edges in the input graph and
-            :math:`d` is :attr:`feat_dim`.
+        dist : Tensor
+            Shortest path distance matrix of the batched graph with zero padding,
+            of shape :math:`(B, N, N)`, where :math:`B` is the batch size of
+            the batched graph, and :math:`N` is the maximum number of nodes.
+        path_data : Tensor
+            Edge feature along the shortest path with zero padding, of shape
+            :math:`(B, N, N, L, d)`, where :math:`L` is the maximum length of
+            the shortest paths, and :math:`d` is :attr:`feat_dim`.

         Returns
         -------
@@ -66,40 +73,14 @@ class PathEncoder(nn.Module):
             the input graph, :math:`N` is the maximum number of nodes, and
             :math:`H` is :attr:`num_heads`.
         """
-        device = g.device
-        g_list = unbatch(g)
-        sum_num_edges = 0
-        max_num_nodes = th.max(g.batch_num_nodes())
-        path_encoding = th.zeros(
-            len(g_list), max_num_nodes, max_num_nodes, self.num_heads
-        ).to(device)
-        for i, ubg in enumerate(g_list):
-            num_nodes = ubg.num_nodes()
-            num_edges = ubg.num_edges()
-            edata = edge_feat[sum_num_edges : (sum_num_edges + num_edges)]
-            sum_num_edges = sum_num_edges + num_edges
-            edata = th.cat(
-                (edata, th.zeros(1, self.feat_dim).to(edata.device)), dim=0
-            )
-            dist, path = shortest_dist(ubg, root=None, return_paths=True)
-            path_len = max(1, min(self.max_len, path.size(dim=2)))
-            # shape: [n, n, l], n = num_nodes, l = path_len
-            shortest_path = path[:, :, 0:path_len]
-            # shape: [n, n]
-            shortest_distance = th.clamp(dist, min=1, max=path_len)
-            # shape: [n, n, l, d], d = feat_dim
-            path_data = edata[shortest_path]
-            # shape: [l, h, d]
-            edge_embedding = self.embedding_table.weight[
-                0 : path_len * self.num_heads
-            ].reshape(path_len, self.num_heads, -1)
-            # [n, n, l, d] einsum [l, h, d] -> [n, n, h]
-            path_encoding[i, :num_nodes, :num_nodes] = th.div(
-                th.einsum("xyld,lhd->xyh", path_data, edge_embedding).permute(
-                    2, 0, 1
-                ),
-                shortest_distance,
-            ).permute(1, 2, 0)
+        shortest_distance = th.clamp(dist, min=1, max=self.max_len)
+        edge_embedding = self.embedding_table.weight.reshape(
+            self.max_len, self.num_heads, -1
+        )
+        path_encoding = th.div(
+            th.einsum("bxyld,lhd->bxyh", path_data, edge_embedding).permute(
+                3, 0, 1, 2
+            ),
+            shortest_distance,
+        ).permute(1, 2, 3, 0)
         return path_encoding
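With this refactor, PathEncoder likewise consumes pre-computed inputs: the shortest-path distance matrix and the edge features gathered along each shortest path, both batched and zero-padded. A minimal single-graph sketch, adapted from the docstring example above (assumes dgl and torch; dgl.shortest_dist marks unreachable pairs with -1, hence the appended zero row in edata):

import dgl
import torch as th
from dgl import shortest_dist
from dgl.nn import PathEncoder

g = dgl.graph(([0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1]))
edata = th.rand(g.num_edges(), 16)
# shortest_dist uses -1 for unreachable node pairs, so index -1 must
# select an all-zero feature row.
edata = th.cat((edata, th.zeros(1, 16)), dim=0)

dist, path = shortest_dist(g, root=None, return_paths=True)  # dist: (N, N), path: (N, N, L)
path_data = edata[path[:, :, :2]]                            # keep at most max_len=2 hops: (N, N, 2, 16)

encoder = PathEncoder(max_len=2, feat_dim=16, num_heads=8)
bias = encoder(dist.unsqueeze(0), path_data.unsqueeze(0))    # add a batch dimension, B=1
assert bias.shape == (1, 4, 4, 8)                            # (B, N, N, num_heads)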
@@ -7,7 +7,6 @@ import torch.nn as nn
 import torch.nn.functional as F

 from ....batch import unbatch
-from ....transforms import shortest_dist

 class SpatialEncoder(nn.Module):
@@ -33,14 +32,19 @@ class SpatialEncoder(nn.Module):
     >>> import torch as th
     >>> import dgl
     >>> from dgl.nn import SpatialEncoder
+    >>> from dgl import shortest_dist
-    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
-    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
-    >>> g = dgl.graph((u, v))
+    >>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
+    >>> g2 = dgl.graph(([0,1], [1,0]))
+    >>> n1, n2 = g1.num_nodes(), g2.num_nodes()
+    >>> # use -1 padding since shortest_dist returns -1 for unreachable node pairs
+    >>> dist = -th.ones((2, 4, 4), dtype=th.long)
+    >>> dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
+    >>> dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)
     >>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
-    >>> out = spatial_encoder(g)
+    >>> out = spatial_encoder(dist)
     >>> print(out.shape)
-    torch.Size([1, 4, 4, 8])
+    torch.Size([2, 4, 4, 8])
     """

     def __init__(self, max_dist, num_heads=1):
@@ -52,41 +56,29 @@ class SpatialEncoder(nn.Module):
             max_dist + 2, num_heads, padding_idx=0
         )

-    def forward(self, g):
+    def forward(self, dist):
         """
         Parameters
         ----------
-        g : DGLGraph
-            A DGLGraph to be encoded, which must be a homogeneous one.
+        dist : Tensor
+            Shortest path distance of the batched graph with -1 padding, a tensor
+            of shape :math:`(B, N, N)`, where :math:`B` is the batch size of
+            the batched graph, and :math:`N` is the maximum number of nodes.

         Returns
         -------
         torch.Tensor
             Return attention bias as spatial encoding of shape
-            :math:`(B, N, N, H)`, where :math:`N` is the maximum number of
-            nodes, :math:`B` is the batch size of the input graph, and
-            :math:`H` is :attr:`num_heads`.
+            :math:`(B, N, N, H)`, where :math:`H` is :attr:`num_heads`.
         """
-        device = g.device
-        g_list = unbatch(g)
-        max_num_nodes = th.max(g.batch_num_nodes())
-        spatial_encoding = th.zeros(
-            len(g_list), max_num_nodes, max_num_nodes, self.num_heads
-        ).to(device)
-        for i, ubg in enumerate(g_list):
-            num_nodes = ubg.num_nodes()
-            dist = (
-                th.clamp(
-                    shortest_dist(ubg, root=None, return_paths=False),
-                    min=-1,
-                    max=self.max_dist,
-                )
-                + 1
-            )
-            # shape: [n, n, h], n = num_nodes, h = num_heads
-            dist_embedding = self.embedding_table(dist)
-            spatial_encoding[i, :num_nodes, :num_nodes] = dist_embedding
+        spatial_encoding = self.embedding_table(
+            th.clamp(
+                dist,
+                min=-1,
+                max=self.max_dist,
+            )
+            + 1
+        )
         return spatial_encoding
...
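SpatialEncoder follows the same pattern: it takes a batched, -1-padded shortest-path distance tensor rather than a graph. A minimal two-graph sketch, adapted from the docstring example above (assumes dgl and torch):

import dgl
import torch as th
from dgl import shortest_dist
from dgl.nn import SpatialEncoder

g1 = dgl.graph(([0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1]))
g2 = dgl.graph(([0, 1], [1, 0]))
n1, n2 = g1.num_nodes(), g2.num_nodes()

# -1 padding matches what shortest_dist returns for unreachable pairs;
# inside forward, dist is clamped and shifted by +1 onto the padding index.
dist = -th.ones((2, 4, 4), dtype=th.long)
dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)

encoder = SpatialEncoder(max_dist=2, num_heads=8)
bias = encoder(dist)                 # attention bias of shape (B, N, N, num_heads)
assert bias.shape == (2, 4, 4, 8)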
@@ -12,6 +12,8 @@ import pytest
 import scipy as sp
 import torch
 import torch as th
+from dgl import shortest_dist
+from torch.nn.utils.rnn import pad_sequence
 from torch.optim import Adam, SparseAdam
 from torch.utils.data import DataLoader
 from utils import parametrize_idtype
@@ -2389,15 +2391,32 @@ def test_DeepWalk():
 @pytest.mark.parametrize("embedding_dim", [8, 16])
 @pytest.mark.parametrize("direction", ["in", "out", "both"])
 def test_degree_encoder(max_degree, embedding_dim, direction):
-    g = dgl.graph(
+    g1 = dgl.graph(
         (
             th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
             th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
         )
     )
+    g2 = dgl.graph(
+        (
+            th.tensor([0, 1]),
+            th.tensor([1, 0]),
+        )
+    )
+    in_degree = pad_sequence(
+        [g1.in_degrees(), g2.in_degrees()], batch_first=True
+    )
+    out_degree = pad_sequence(
+        [g1.out_degrees(), g2.out_degrees()], batch_first=True
+    )
     model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
-    de_g = model(g)
-    assert de_g.shape == (4, embedding_dim)
+    if direction == "in":
+        de_g = model(in_degree)
+    elif direction == "out":
+        de_g = model(out_degree)
+    elif direction == "both":
+        de_g = model(th.stack((in_degree, out_degree)))
+    assert de_g.shape == (2, 4, embedding_dim)

 @parametrize_idtype
@@ -2498,25 +2517,24 @@ def test_GraphormerLayer(attn_bias_type, norm_first):
     assert out.shape == (batch_size, num_nodes, feat_size)

-@pytest.mark.parametrize("max_len", [1, 4])
+@pytest.mark.parametrize("max_len", [1, 2])
 @pytest.mark.parametrize("feat_dim", [16])
 @pytest.mark.parametrize("num_heads", [1, 8])
 def test_PathEncoder(max_len, feat_dim, num_heads):
     dev = F.ctx()
-    g1 = dgl.graph(
+    g = dgl.graph(
         (
             th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
             th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
         )
     ).to(dev)
-    g2 = dgl.graph(
-        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))
-    ).to(dev)
-    bg = dgl.batch([g1, g2])
-    edge_feat = th.rand(bg.num_edges(), feat_dim).to(dev)
+    edge_feat = th.rand(g.num_edges(), feat_dim).to(dev)
+    edge_feat = th.cat((edge_feat, th.zeros(1, 16).to(dev)), dim=0)
+    dist, path = shortest_dist(g, root=None, return_paths=True)
+    path_data = edge_feat[path[:, :, :max_len]]
     model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
-    bias = model(bg, edge_feat)
-    assert bias.shape == (2, 6, 6, num_heads)
+    bias = model(dist.unsqueeze(0), path_data.unsqueeze(0))
+    assert bias.shape == (1, 4, 4, num_heads)

 @pytest.mark.parametrize("max_dist", [1, 4])
@@ -2537,12 +2555,15 @@ def test_SpatialEncoder(max_dist, num_kernels, num_heads):
     ndata = th.rand(bg.num_nodes(), 3).to(dev)
     num_nodes = bg.num_nodes()
     node_type = th.randint(0, 512, (num_nodes,)).to(dev)
+    dist = -th.ones((2, 6, 6), dtype=th.long).to(dev)
+    dist[0, :4, :4] = shortest_dist(g1, root=None, return_paths=False)
+    dist[1, :6, :6] = shortest_dist(g2, root=None, return_paths=False)
     model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
     model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
     model_3 = nn.SpatialEncoder3d(
         num_kernels, num_heads=num_heads, max_node_type=512
     ).to(dev)
-    encoding = model_1(bg)
+    encoding = model_1(dist)
     encoding3d_1 = model_2(bg, ndata)
     encoding3d_2 = model_3(bg, ndata, node_type)
     assert encoding.shape == (2, 6, 6, num_heads)
...