Unverified Commit 93ecfa31 authored by ZhenyuLU_Heliodore, committed by GitHub

[NN] Add PathEncoder (#4956)



* Add PathEncoder to transformer.py

* add blank line at the end

* rename variable sp to shortest_path

* Fixed corresponding problems

* Fixed certain bugs when running on CUDA

* changed clamp min from 0 to 1
Co-authored-by: Ubuntu <ubuntu@ip-172-31-14-146.ap-northeast-1.compute.internal>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent b84de903
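
At its core, the new module implements Graphormer's edge encoding: for each node pair, the edge features along the (truncated) shortest path are matched against learnable per-hop, per-head weight vectors and averaged over the clamped path length. A rough per-pair sketch of that computation, using a hypothetical standalone function (not part of the commit) with the same shape conventions as the code in the diff below:

import torch as th

def path_bias_for_pair(path_edge_feats, hop_weights, dist):
    """Hypothetical illustration of the per-pair edge encoding.

    path_edge_feats : (L, d) features of the first L edges on the shortest path
    hop_weights     : (L, d, h) learnable weight per hop and attention head
    dist            : clamped path length (>= 1) used for averaging
    """
    # Dot product of each hop's edge feature with its weight vectors: (L, h)
    per_hop = th.einsum("ld,ldh->lh", path_edge_feats, hop_weights)
    # Averaging over hops gives one bias value per attention head: (h,)
    return per_hop.sum(dim=0) / dist
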
@@ -118,6 +118,7 @@ Utility Modules
     ~dgl.nn.pytorch.graph_transformer.DegreeEncoder
     ~dgl.nn.pytorch.utils.LaplacianPosEnc
     ~dgl.nn.pytorch.graph_transformer.BiasedMultiheadAttention
+    ~dgl.nn.pytorch.graph_transformer.PathEncoder

 Network Embedding Modules
 ----------------------------------------
@@ -2,9 +2,14 @@
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl
+from ...convert import to_homogeneous
+from ...batch import unbatch
+from ...transforms import shortest_dist

-__all__ = ["DegreeEncoder", "BiasedMultiheadAttention"]
+__all__ = ["DegreeEncoder",
+           "BiasedMultiheadAttention",
+           "PathEncoder"]

 class DegreeEncoder(nn.Module):
     r"""Degree Encoder, as introduced in
@@ -67,7 +72,7 @@ class DegreeEncoder(nn.Module):
         where :math:`N` is the number of nodes in the input graph.
         """
         if len(g.ntypes) > 1 or len(g.etypes) > 1:
-            g = dgl.to_homogeneous(g)
+            g = to_homogeneous(g)
         in_degree = th.clamp(g.in_degrees(), min=0, max=self.max_degree)
         out_degree = th.clamp(g.out_degrees(), min=0, max=self.max_degree)
@@ -219,3 +224,115 @@ class BiasedMultiheadAttention(nn.Module):
        attn = self.out_proj(attn.reshape(N, bsz, self.feat_size).transpose(0, 1))

        return attn

class PathEncoder(nn.Module):
    r"""Path Encoder, as introduced in Edge Encoding of
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__

    This module is a learnable path embedding module which encodes the shortest
    path between each pair of nodes as an attention bias.

    Parameters
    ----------
    max_len : int
        Maximum number of edges in each path to be encoded.
        The exceeding part of each path is truncated, i.e.
        edges whose index along the path is no less than :attr:`max_len` are dropped.
    feat_dim : int
        Dimension of edge features in the input graph.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism is applied.
        Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> edata = th.rand(8, 16)
    >>> path_encoder = dgl.nn.PathEncoder(2, 16, num_heads=8)
    >>> out = path_encoder(g, edata)
    """

    def __init__(self, max_len, feat_dim, num_heads=1):
        super().__init__()
        self.max_len = max_len
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        # One learnable weight vector of size feat_dim per (hop, head) pair.
        self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)
    def forward(self, g, edge_feat):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded, which must be a homogeneous one.
        edge_feat : torch.Tensor
            The input edge feature of shape :math:`(E, feat_dim)`,
            where :math:`E` is the number of edges in the input graph.

        Returns
        -------
        torch.Tensor
            Attention bias as path encoding,
            of shape :math:`(batch_size, N, N, num_heads)`,
            where :math:`N` is the maximum number of nodes per graph
            and :math:`batch_size` is the batch size of the input graph.
        """
        g_list = unbatch(g)
        sum_num_edges = 0
        max_num_nodes = th.max(g.batch_num_nodes())
        path_encoding = []

        for ubg in g_list:
            num_nodes = ubg.num_nodes()
            num_edges = ubg.num_edges()
            edata = edge_feat[sum_num_edges: (sum_num_edges + num_edges)]
            sum_num_edges = sum_num_edges + num_edges
            # Append a zero row so that the -1 padding returned by
            # shortest_dist indexes an all-zero edge feature.
            edata = th.cat(
                (edata, th.zeros(1, self.feat_dim).to(edata.device)),
                dim=0
            )
            _, path = shortest_dist(ubg, root=None, return_paths=True)
            path_len = min(self.max_len, path.size(dim=2))

            # shape: [n, n, l], n = num_nodes, l = path_len
            shortest_path = path[:, :, 0: path_len]
            # shape: [n, n]
            shortest_distance = th.clamp(
                shortest_dist(ubg, root=None, return_paths=False),
                min=1,
                max=path_len
            )
            # shape: [n, n, l, d], d = feat_dim
            path_data = edata[shortest_path]
            # shape: [l, h], h = num_heads
            embedding_idx = th.reshape(
                th.arange(self.num_heads * path_len),
                (path_len, self.num_heads)
            ).to(next(self.embedding_table.parameters()).device)
            # shape: [d, l, h]
            edge_embedding = th.permute(
                self.embedding_table(embedding_idx), (2, 0, 1)
            )

            # [n, n, l, d] einsum [d, l, h] -> [n, n, h]
            # [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf
            sub_encoding = th.full(
                (max_num_nodes, max_num_nodes, self.num_heads),
                float('-inf')
            )
            sub_encoding[0: num_nodes, 0: num_nodes] = th.div(
                th.einsum(
                    'xyld,dlh->xyh', path_data, edge_embedding
                ).permute(2, 0, 1),
                shortest_distance
            ).permute(1, 2, 0)
            path_encoding.append(sub_encoding)

        return th.stack(path_encoding, dim=0)
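
For quick orientation, a minimal usage sketch of the new module on a batched graph; it assumes the class is re-exported as `dgl.nn.PathEncoder`, as the docs table in this commit suggests, and shapes follow the docstring above:

import torch as th
import dgl
from dgl.nn import PathEncoder  # assumed public alias, per the docs table above

# Two small directed graphs, batched together.
g1 = dgl.graph((th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
                th.tensor([1, 2, 3, 0, 3, 0, 0, 1])))         # 4 nodes
g2 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))  # 3 nodes
bg = dgl.batch([g1, g2])

edge_feat = th.rand(bg.num_edges(), 16)                       # (E, feat_dim)
path_encoder = PathEncoder(max_len=2, feat_dim=16, num_heads=8)
bias = path_encoder(bg, edge_feat)

# (batch_size, N, N, num_heads), N = max nodes per graph in the batch;
# rows/columns of padded nodes stay -inf so they vanish after softmax
# when the bias is added to attention scores.
print(bias.shape)  # torch.Size([2, 4, 4, 8])
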
@@ -1516,7 +1516,7 @@ def test_hgt(idtype, in_size, num_heads):
     sorted_y = m(sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False)
     assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
     # mini-batch
-    train_idx = th.randperm(100, dtype=idtype)[:10]
+    train_idx = th.randint(0, 100, (10, ), dtype = idtype)
     sampler = dgl.dataloading.NeighborSampler([-1])
     train_loader = dgl.dataloading.DataLoader(g, train_idx.to(dev), sampler,
                                               batch_size=8, device=dev,
@@ -1801,3 +1801,23 @@ def test_BiasedMultiheadAttention(feat_size, num_heads, bias, attn_bias_type, att
    out = net(ndata, attn_bias, attn_mask)
    assert out.shape == (16, 100, feat_size)

@pytest.mark.parametrize('max_len', [1, 4])
@pytest.mark.parametrize('feat_dim', [16])
@pytest.mark.parametrize('num_heads', [1, 8])
def test_PathEncoder(max_len, feat_dim, num_heads):
    dev = F.ctx()
    g1 = dgl.graph((
        th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
        th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    )).to(dev)
    g2 = dgl.graph((
        th.tensor([0, 1, 2, 3, 2, 5]),
        th.tensor([1, 2, 3, 4, 0, 3])
    )).to(dev)
    bg = dgl.batch([g1, g2])
    edge_feat = th.rand(bg.num_edges(), feat_dim).to(dev)
    model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
    bias = model(bg, edge_feat)
    assert bias.shape == (2, 6, 6, num_heads)
\ No newline at end of file
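
The bias produced by PathEncoder is shaped to slot into BiasedMultiheadAttention from the same file as its attention-bias input. A rough end-to-end sketch, assuming both classes are re-exported under `dgl.nn`, that the constructor takes `feat_size` and `num_heads` as in its test above, and that the mask argument may be omitted; the densely padded node features are a stand-in for a real batching pipeline:

import torch as th
import dgl
from dgl.nn import BiasedMultiheadAttention, PathEncoder  # assumed public aliases

g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))  # 3-node cycle
bg = dgl.batch([g, g])                                       # batch_size = 2, N = 3

edge_feat = th.rand(bg.num_edges(), 16)
path_encoder = PathEncoder(max_len=2, feat_dim=16, num_heads=8)
bias = path_encoder(bg, edge_feat)                           # (2, 3, 3, 8)

# Dense node features; a real pipeline would pad per-graph node features to N.
node_feat = th.rand(2, 3, 16)                                # (batch_size, N, feat_size)

attn = BiasedMultiheadAttention(feat_size=16, num_heads=8)
out = attn(node_feat, bias)                                  # bias added to attention scores
print(out.shape)  # torch.Size([2, 3, 16])
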