Unverified Commit 9d9280cb authored by ZhenyuLU_Heliodore's avatar ZhenyuLU_Heliodore Committed by GitHub
Browse files

[NN] Add DegreeEncoder for graph transformer (#4742)



* Add files via upload

This file will contain several utility modules for Graph Transformer. DegreeEncoder has been implemented in this file now.

* Update graph_transformer.py

* Update nn-pytorch.rst

Add graph_transformer.DegreeEncoder in Utility Modules

* Update test_nn.py

* Update __init__.py

* Update graph_transformer.py

* Update test_nn.py

* Update graph_transformer.py

fix unused import

* Update graph_transformer.py

add module doc-string

* Update graph_transformer.py

* Update graph_transformer.py
Co-authored-by: default avatarrudongyu <ru_dongyu@outlook.com>
parent c5e83757
@@ -114,6 +114,7 @@ Utility Modules
~dgl.nn.pytorch.explain.GNNExplainer
~dgl.nn.pytorch.explain.HeteroGNNExplainer
~dgl.nn.pytorch.utils.LabelPropagation
~dgl.nn.pytorch.graph_transformer.DegreeEncoder
~dgl.nn.pytorch.utils.LaplacianPosEnc
Network Embedding Modules
......
@@ -10,3 +10,4 @@ from .hetero import *
from .utils import Sequential, WeightBasis, JumpingKnowledge, LabelPropagation, LaplacianPosEnc
from .sparse_emb import NodeEmbedding
from .network_emb import *
from .graph_transformer import *
"""Torch modules for graph transformers."""
import torch as th
import torch.nn as nn
import dgl
class DegreeEncoder(nn.Module):
    r"""Degree Encoder, as introduced in
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__
    This module is a learnable degree embedding module.
    Parameters
    ----------
    max_degree : int
        Upper bound of degrees to be encoded.
        Each degree will be clamped into the range [0, ``max_degree``].
    embedding_dim : int
        Output dimension of embedding vectors.
    direction : str, optional
        Degrees of which direction to be encoded,
        selected from ``in``, ``out`` and ``both``.
        ``both`` encodes degrees from both directions
        and outputs the sum of them.
        Default : ``both``.
    Example
    -------
    >>> import dgl
    >>> from dgl.nn import DegreeEncoder
    >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
    >>> degree_encoder = DegreeEncoder(5, 16)
    >>> degree_embedding = degree_encoder(g)
    """
    def __init__(self, max_degree, embedding_dim, direction="both"):
        super(DegreeEncoder, self).__init__()
        # Fail fast on an invalid direction instead of deferring the
        # error to the first forward() call.
        if direction not in ("in", "out", "both"):
            raise ValueError(
                f'Supported direction options: "in", "out" and "both", '
                f'but got {direction}'
            )
        self.direction = direction
        if direction == "both":
            # Two independent tables: one for in-degrees, one for out-degrees.
            # padding_idx=0 pins isolated (degree-0) nodes to a zero vector
            # that receives no gradient.
            self.degree_encoder_1 = nn.Embedding(
                max_degree + 1, embedding_dim, padding_idx=0
            )
            self.degree_encoder_2 = nn.Embedding(
                max_degree + 1, embedding_dim, padding_idx=0
            )
        else:
            self.degree_encoder = nn.Embedding(
                max_degree + 1, embedding_dim, padding_idx=0
            )
        self.max_degree = max_degree
    def forward(self, g):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded. If it is a heterogeneous one,
            it will be transformed into a homogeneous one first.
        Returns
        -------
        Tensor
            Return degree embedding vectors of shape :math:`(N, embedding_dim)`,
            where :math:`N` is the number of nodes in the input graph.
        """
        # Flatten heterogeneous graphs so every node has a single
        # in-/out-degree to look up.
        if len(g.ntypes) > 1 or len(g.etypes) > 1:
            g = dgl.to_homogeneous(g)
        # Clamp degrees above max_degree to the last embedding row.
        in_degree = th.clamp(g.in_degrees(), min=0, max=self.max_degree)
        out_degree = th.clamp(g.out_degrees(), min=0, max=self.max_degree)
        if self.direction == "in":
            degree_embedding = self.degree_encoder(in_degree)
        elif self.direction == "out":
            degree_embedding = self.degree_encoder(out_degree)
        elif self.direction == "both":
            degree_embedding = (self.degree_encoder_1(in_degree)
                                + self.degree_encoder_2(out_degree))
        else:
            # Unreachable via __init__, but guards against direct mutation
            # of self.direction after construction.
            raise ValueError(
                f'Supported direction options: "in", "out" and "both", '
                f'but got {self.direction}'
            )
        return degree_embedding
@@ -1718,6 +1718,26 @@ def test_DeepWalk():
loss.backward()
optim.step()
@pytest.mark.parametrize('max_degree', [2, 6])
@pytest.mark.parametrize('embedding_dim', [8, 16])
@pytest.mark.parametrize('direction', ['in', 'out', 'both'])
def test_degree_encoder(max_degree, embedding_dim, direction):
    """Check DegreeEncoder output shapes on homo- and heterographs."""
    # Homogeneous graph: 4 nodes, 8 directed edges.
    src = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    dst = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    g = dgl.graph((src, dst))
    # test heterograph
    hg = dgl.heterograph({
        ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
        ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),
        ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))
    })
    model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
    # One embedding row per node, regardless of graph kind.
    embeddings = model(g)
    hetero_embeddings = model(hg)
    assert embeddings.shape == (4, embedding_dim)
    assert hetero_embeddings.shape == (10, embedding_dim)
@parametrize_idtype
def test_MetaPath2Vec(idtype):
    dev = F.ctx()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment