Unverified Commit e47a0279 authored by Zhiteng Li, committed by GitHub

[NN] Refactor DegreeEncoder, SpatialEncoder and PathEncoder (#5799)


Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent 4fd0a158
......@@ -3,8 +3,6 @@
import torch as th
import torch.nn as nn
from ....base import DGLError
class DegreeEncoder(nn.Module):
r"""Degree Encoder, as introduced in
......@@ -31,10 +29,19 @@ class DegreeEncoder(nn.Module):
-------
>>> import dgl
>>> from dgl.nn import DegreeEncoder
>>> import torch as th
>>> from torch.nn.utils.rnn import pad_sequence
>>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> g2 = dgl.graph(([0,1], [1,0]))
>>> in_degree = pad_sequence([g1.in_degrees(), g2.in_degrees()], batch_first=True)
>>> out_degree = pad_sequence([g1.out_degrees(), g2.out_degrees()], batch_first=True)
>>> print(in_degree.shape)
torch.Size([2, 4])
>>> degree_encoder = DegreeEncoder(5, 16)
>>> degree_embedding = degree_encoder(g)
>>> degree_embedding = degree_encoder(th.stack((in_degree, out_degree)))
>>> print(degree_embedding.shape)
torch.Size([2, 4, 16])
"""
def __init__(self, max_degree, embedding_dim, direction="both"):
......@@ -53,36 +60,35 @@ class DegreeEncoder(nn.Module):
)
self.max_degree = max_degree
def forward(self, g):
def forward(self, degrees):
"""
Parameters
----------
g : DGLGraph
A DGLGraph to be encoded. Graphs with more than one type of edges
are not allowed.
degrees : Tensor
If :attr:`direction` is ``both``, it should be the zero-padded in-degrees and
out-degrees of the batched graph, stacked into a tensor of shape :math:`(2, B, N)`.
Otherwise, it should be the zero-padded in-degrees or out-degrees of the batched
graph, a tensor of shape :math:`(B, N)`, where :math:`B` is the batch size of the
batched graph and :math:`N` is the maximum number of nodes.
Returns
-------
Tensor
Return degree embedding vectors of shape :math:`(N, d)`,
where :math:`N` is the number of nodes in the input graph and
:math:`d` is :attr:`embedding_dim`.
Return degree embedding vectors of shape :math:`(B, N, d)`,
where :math:`d` is :attr:`embedding_dim`.
"""
if len(g.etypes) > 1:
raise DGLError(
"The input graph should have no more than one type of edges."
)
in_degree = th.clamp(g.in_degrees(), min=0, max=self.max_degree)
out_degree = th.clamp(g.out_degrees(), min=0, max=self.max_degree)
degrees = th.clamp(degrees, min=0, max=self.max_degree)
if self.direction == "in":
degree_embedding = self.encoder(in_degree)
assert len(degrees.shape) == 2
degree_embedding = self.encoder(degrees)
elif self.direction == "out":
degree_embedding = self.encoder(out_degree)
assert len(degrees.shape) == 2
degree_embedding = self.encoder(degrees)
elif self.direction == "both":
degree_embedding = self.encoder1(in_degree) + self.encoder2(
out_degree
assert len(degrees.shape) == 3 and degrees.shape[0] == 2
degree_embedding = self.encoder1(degrees[0]) + self.encoder2(
degrees[1]
)
else:
raise ValueError(
......
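Since the refactored forward now takes a degree tensor instead of a DGLGraph, callers that previously passed a batched graph have to build the zero-padded input themselves. A minimal sketch of one way to do this, assuming a batched homogeneous graph bg; the helper prepare_degrees is hypothetical and not part of this commit:

import dgl
import torch as th
from torch.nn.utils.rnn import pad_sequence
from dgl.nn import DegreeEncoder

def prepare_degrees(bg):
    # Zero-pad each graph's degree vector to N, the maximum number of nodes.
    graphs = dgl.unbatch(bg)
    in_degree = pad_sequence([g.in_degrees() for g in graphs], batch_first=True)
    out_degree = pad_sequence([g.out_degrees() for g in graphs], batch_first=True)
    # Shape (2, B, N), as expected when direction == "both".
    return th.stack((in_degree, out_degree))

g1 = dgl.graph(([0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1]))
g2 = dgl.graph(([0, 1], [1, 0]))
bg = dgl.batch([g1, g2])
degree_encoder = DegreeEncoder(max_degree=5, embedding_dim=16)
degree_embedding = degree_encoder(prepare_degrees(bg))  # shape (2, 4, 16)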
......@@ -2,9 +2,6 @@
import torch as th
import torch.nn as nn
from ....batch import unbatch
from ....transforms import shortest_dist
class PathEncoder(nn.Module):
r"""Path Encoder, as introduced in Edge Encoding of
......@@ -31,13 +28,21 @@ class PathEncoder(nn.Module):
>>> import torch as th
>>> import dgl
>>> from dgl.nn import PathEncoder
>>> from dgl import shortest_dist
>>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
>>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
>>> g = dgl.graph((u, v))
>>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> edata = th.rand(8, 16)
>>> # Since shortest_dist returns -1 for unreachable node pairs,
>>> # append a zero row so that edata[-1] gives zero padding.
>>> edata = th.cat(
...     (edata, th.zeros(1, 16)), dim=0
... )
>>> dist, path = shortest_dist(g, root=None, return_paths=True)
>>> path_data = edata[path[:, :, :2]]
>>> path_encoder = PathEncoder(2, 16, num_heads=8)
>>> out = path_encoder(g, edata)
>>> out = path_encoder(dist.unsqueeze(0), path_data.unsqueeze(0))
>>> print(out.shape)
torch.Size([1, 4, 4, 8])
"""
def __init__(self, max_len, feat_dim, num_heads=1):
......@@ -47,16 +52,18 @@ class PathEncoder(nn.Module):
self.num_heads = num_heads
self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)
def forward(self, g, edge_feat):
def forward(self, dist, path_data):
"""
Parameters
----------
g : DGLGraph
A DGLGraph to be encoded, which must be a homogeneous one.
edge_feat : torch.Tensor
The input edge feature of shape :math:`(E, d)`,
where :math:`E` is the number of edges in the input graph and
:math:`d` is :attr:`feat_dim`.
dist : Tensor
Shortest path distance matrix of the batched graph with zero padding,
of shape :math:`(B, N, N)`, where :math:`B` is the batch size of
the batched graph, and :math:`N` is the maximum number of nodes.
path_data : Tensor
Edge features along the shortest paths, with zero padding, of shape
:math:`(B, N, N, L, d)`, where :math:`L` is the maximum length of
the shortest paths and :math:`d` is :attr:`feat_dim`.
Returns
-------
......@@ -66,40 +73,14 @@ class PathEncoder(nn.Module):
the input graph, :math:`N` is the maximum number of nodes, and
:math:`H` is :attr:`num_heads`.
"""
device = g.device
g_list = unbatch(g)
sum_num_edges = 0
max_num_nodes = th.max(g.batch_num_nodes())
path_encoding = th.zeros(
len(g_list), max_num_nodes, max_num_nodes, self.num_heads
).to(device)
for i, ubg in enumerate(g_list):
num_nodes = ubg.num_nodes()
num_edges = ubg.num_edges()
edata = edge_feat[sum_num_edges : (sum_num_edges + num_edges)]
sum_num_edges = sum_num_edges + num_edges
edata = th.cat(
(edata, th.zeros(1, self.feat_dim).to(edata.device)), dim=0
shortest_distance = th.clamp(dist, min=1, max=self.max_len)
edge_embedding = self.embedding_table.weight.reshape(
self.max_len, self.num_heads, -1
)
dist, path = shortest_dist(ubg, root=None, return_paths=True)
path_len = max(1, min(self.max_len, path.size(dim=2)))
# shape: [n, n, l], n = num_nodes, l = path_len
shortest_path = path[:, :, 0:path_len]
# shape: [n, n]
shortest_distance = th.clamp(dist, min=1, max=path_len)
# shape: [n, n, l, d], d = feat_dim
path_data = edata[shortest_path]
# shape: [l, h, d]
edge_embedding = self.embedding_table.weight[
0 : path_len * self.num_heads
].reshape(path_len, self.num_heads, -1)
# [n, n, l, d] einsum [l, h, d] -> [n, n, h]
path_encoding[i, :num_nodes, :num_nodes] = th.div(
th.einsum("xyld,lhd->xyh", path_data, edge_embedding).permute(
2, 0, 1
path_encoding = th.div(
th.einsum("bxyld,lhd->bxyh", path_data, edge_embedding).permute(
3, 0, 1, 2
),
shortest_distance,
).permute(1, 2, 0)
).permute(1, 2, 3, 0)
return path_encoding
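For reference, a minimal sketch (not part of this commit) of how the batched dist and path_data inputs can be assembled by hand for two toy graphs, following the zero-padding convention described in the docstring above; the max_len=2 and feat_dim=16 settings are illustrative:

import dgl
import torch as th
from dgl import shortest_dist
from dgl.nn import PathEncoder

max_len, feat_dim, num_heads = 2, 16, 8
g1 = dgl.graph(([0, 0, 0, 1, 1, 2, 3, 3], [1, 2, 3, 0, 3, 0, 0, 1]))
g2 = dgl.graph(([0, 1], [1, 0]))
graphs = [g1, g2]
N = max(g.num_nodes() for g in graphs)  # maximum number of nodes, here 4

dist = th.zeros(len(graphs), N, N, dtype=th.long)
path_data = th.zeros(len(graphs), N, N, max_len, feat_dim)
for i, g in enumerate(graphs):
    n = g.num_nodes()
    edata = th.rand(g.num_edges(), feat_dim)
    # shortest_dist uses -1 for missing path entries, so append a zero row
    # so that index -1 selects zero padding instead of a real edge feature.
    edata = th.cat((edata, th.zeros(1, feat_dim)), dim=0)
    d, path = shortest_dist(g, root=None, return_paths=True)
    p = path[:, :, :max_len]  # (n, n, l), l may be shorter than max_len
    dist[i, :n, :n] = d
    path_data[i, :n, :n, : p.shape[2]] = edata[p]

path_encoder = PathEncoder(max_len, feat_dim, num_heads=num_heads)
out = path_encoder(dist, path_data)  # shape (2, 4, 4, 8)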
......@@ -7,7 +7,6 @@ import torch.nn as nn
import torch.nn.functional as F
from ....batch import unbatch
from ....transforms import shortest_dist
class SpatialEncoder(nn.Module):
......@@ -33,14 +32,19 @@ class SpatialEncoder(nn.Module):
>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder
>>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
>>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
>>> g = dgl.graph((u, v))
>>> from dgl import shortest_dist
>>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> g2 = dgl.graph(([0,1], [1,0]))
>>> n1, n2 = g1.num_nodes(), g2.num_nodes()
>>> # use -1 padding since shortest_dist returns -1 for unreachable node pairs
>>> dist = -th.ones((2, 4, 4), dtype=th.long)
>>> dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
>>> dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)
>>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
>>> out = spatial_encoder(g)
>>> out = spatial_encoder(dist)
>>> print(out.shape)
torch.Size([1, 4, 4, 8])
torch.Size([2, 4, 4, 8])
"""
def __init__(self, max_dist, num_heads=1):
......@@ -52,41 +56,29 @@ class SpatialEncoder(nn.Module):
max_dist + 2, num_heads, padding_idx=0
)
def forward(self, g):
def forward(self, dist):
"""
Parameters
----------
g : DGLGraph
A DGLGraph to be encoded, which must be a homogeneous one.
dist : Tensor
Shortest path distance matrix of the batched graph with -1 padding, a tensor
of shape :math:`(B, N, N)`, where :math:`B` is the batch size of
the batched graph, and :math:`N` is the maximum number of nodes.
Returns
-------
torch.Tensor
Return attention bias as spatial encoding of shape
:math:`(B, N, N, H)`, where :math:`N` is the maximum number of
nodes, :math:`B` is the batch size of the input graph, and
:math:`H` is :attr:`num_heads`.
:math:`(B, N, N, H)`, where :math:`H` is :attr:`num_heads`.
"""
device = g.device
g_list = unbatch(g)
max_num_nodes = th.max(g.batch_num_nodes())
spatial_encoding = th.zeros(
len(g_list), max_num_nodes, max_num_nodes, self.num_heads
).to(device)
for i, ubg in enumerate(g_list):
num_nodes = ubg.num_nodes()
dist = (
spatial_encoding = self.embedding_table(
th.clamp(
shortest_dist(ubg, root=None, return_paths=False),
dist,
min=-1,
max=self.max_dist,
)
+ 1
)
# shape: [n, n, h], n = num_nodes, h = num_heads
dist_embedding = self.embedding_table(dist)
spatial_encoding[i, :num_nodes, :num_nodes] = dist_embedding
return spatial_encoding
......
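Note that the clamp to [-1, max_dist] followed by the +1 shift above maps unreachable pairs and -1 padding to embedding index 0, which is the table's padding_idx, so they contribute a zero bias. A minimal standalone sketch of that lookup; the toy dist tensor is illustrative, not from this commit:

import torch as th
import torch.nn as nn

max_dist, num_heads = 2, 8
# Same construction as SpatialEncoder: index 0 is reserved for padding.
embedding_table = nn.Embedding(max_dist + 2, num_heads, padding_idx=0)

# -1 marks unreachable pairs or padding; larger distances are clamped to max_dist.
dist = th.tensor([[0, 1, 2], [1, 0, -1], [5, -1, 0]])
indices = th.clamp(dist, min=-1, max=max_dist) + 1  # values in [0, max_dist + 1]
bias = embedding_table(indices)  # shape (3, 3, num_heads); index-0 rows are all zeros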
......@@ -12,6 +12,8 @@ import pytest
import scipy as sp
import torch
import torch as th
from dgl import shortest_dist
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam, SparseAdam
from torch.utils.data import DataLoader
from utils import parametrize_idtype
......@@ -2389,15 +2391,32 @@ def test_DeepWalk():
@pytest.mark.parametrize("embedding_dim", [8, 16])
@pytest.mark.parametrize("direction", ["in", "out", "both"])
def test_degree_encoder(max_degree, embedding_dim, direction):
g = dgl.graph(
g1 = dgl.graph(
(
th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
)
)
g2 = dgl.graph(
(
th.tensor([0, 1]),
th.tensor([1, 0]),
)
)
in_degree = pad_sequence(
[g1.in_degrees(), g2.in_degrees()], batch_first=True
)
out_degree = pad_sequence(
[g1.out_degrees(), g2.out_degrees()], batch_first=True
)
model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
de_g = model(g)
assert de_g.shape == (4, embedding_dim)
if direction == "in":
de_g = model(in_degree)
elif direction == "out":
de_g = model(out_degree)
elif direction == "both":
de_g = model(th.stack((in_degree, out_degree)))
assert de_g.shape == (2, 4, embedding_dim)
@parametrize_idtype
......@@ -2498,25 +2517,24 @@ def test_GraphormerLayer(attn_bias_type, norm_first):
assert out.shape == (batch_size, num_nodes, feat_size)
@pytest.mark.parametrize("max_len", [1, 4])
@pytest.mark.parametrize("max_len", [1, 2])
@pytest.mark.parametrize("feat_dim", [16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_PathEncoder(max_len, feat_dim, num_heads):
dev = F.ctx()
g1 = dgl.graph(
g = dgl.graph(
(
th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
)
).to(dev)
g2 = dgl.graph(
(th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))
).to(dev)
bg = dgl.batch([g1, g2])
edge_feat = th.rand(bg.num_edges(), feat_dim).to(dev)
edge_feat = th.rand(g.num_edges(), feat_dim).to(dev)
edge_feat = th.cat((edge_feat, th.zeros(1, feat_dim).to(dev)), dim=0)
dist, path = shortest_dist(g, root=None, return_paths=True)
path_data = edge_feat[path[:, :, :max_len]]
model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
bias = model(bg, edge_feat)
assert bias.shape == (2, 6, 6, num_heads)
bias = model(dist.unsqueeze(0), path_data.unsqueeze(0))
assert bias.shape == (1, 4, 4, num_heads)
@pytest.mark.parametrize("max_dist", [1, 4])
......@@ -2537,12 +2555,15 @@ def test_SpatialEncoder(max_dist, num_kernels, num_heads):
ndata = th.rand(bg.num_nodes(), 3).to(dev)
num_nodes = bg.num_nodes()
node_type = th.randint(0, 512, (num_nodes,)).to(dev)
dist = -th.ones((2, 6, 6), dtype=th.long).to(dev)
dist[0, :4, :4] = shortest_dist(g1, root=None, return_paths=False)
dist[1, :6, :6] = shortest_dist(g2, root=None, return_paths=False)
model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
model_3 = nn.SpatialEncoder3d(
num_kernels, num_heads=num_heads, max_node_type=512
).to(dev)
encoding = model_1(bg)
encoding = model_1(dist)
encoding3d_1 = model_2(bg, ndata)
encoding3d_2 = model_3(bg, ndata, node_type)
assert encoding.shape == (2, 6, 6, num_heads)
......