"...pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "19096c6a8e7f1fb6f97bd2b43d1e9bde80a7a47f"
Unverified Commit ed66a209 authored by Zhiteng Li, committed by GitHub

[NN] Add a learned laplacian positional encoder (#4750)



* add a learned laplacian positional encoder

* leverage black to beautify the python code

* refine according to dongyu's comments
Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent 2bca4759
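
For reviewer context (not part of the diff): a minimal sketch of the input pipeline the new encoder assumes, mirroring the docstring example further down. The LaplacianPE transform attaches per-node eigenvalue/eigenvector features, which LaplacianPosEnc then encodes; the graph and feature names here are illustrative.

import dgl
from dgl import LaplacianPE

# Toy graph (same edges as the docstring example below).
g = dgl.graph(([0, 1, 2, 3, 4, 2, 3, 1, 4, 0], [2, 3, 1, 4, 0, 0, 1, 2, 3, 4]))

# Attach the k smallest non-trivial Laplacian eigenvectors and their
# eigenvalues as node features; these are the inputs LaplacianPosEnc expects.
transform = LaplacianPE(k=5, feat_name="eigvec", eigval_name="eigval", padding=True)
g = transform(g)

eigvals, eigvecs = g.ndata["eigval"], g.ndata["eigvec"]  # both shaped (N, k)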
@@ -114,6 +114,7 @@ Utility Modules
    ~dgl.nn.pytorch.explain.GNNExplainer
    ~dgl.nn.pytorch.explain.HeteroGNNExplainer
    ~dgl.nn.pytorch.utils.LabelPropagation
    ~dgl.nn.pytorch.utils.LaplacianPosEnc

Network Embedding Modules
----------------------------------------
@@ -7,6 +7,6 @@ from .glob import *
from .softmax import *
from .factory import *
from .hetero import *
from .utils import Sequential, WeightBasis, JumpingKnowledge, LabelPropagation, LaplacianPosEnc
from .sparse_emb import NodeEmbedding
from .network_emb import *
@@ -555,3 +555,155 @@ class LabelPropagation(nn.Module):
            y[mask] = labels[mask]

        return y


class LaplacianPosEnc(nn.Module):
    r"""Laplacian Positional Encoder (LPE), as introduced in
    `GraphGPS: General Powerful Scalable Graph Transformers
    <https://arxiv.org/abs/2205.12454>`__

    This module computes a learned Laplacian positional encoding using a
    Transformer or DeepSet encoder.

    Parameters
    ----------
    model_type : str
        Encoder model type for LPE, can only be "Transformer" or "DeepSet".
    num_layer : int
        Number of layers in the Transformer/DeepSet encoder.
    k : int
        Number of smallest non-trivial eigenvectors.
    lpe_dim : int
        Output size of the final Laplacian encoding.
    n_head : int, optional
        Number of heads in the Transformer encoder.
        Default: 1.
    batch_norm : bool, optional
        If True, apply batch normalization on the raw LaplacianPE.
        Default: False.
    num_post_layer : int, optional
        If num_post_layer > 0, apply an MLP of ``num_post_layer`` layers
        after pooling.
        Default: 0.

    Example
    -------
    >>> import dgl
    >>> from dgl import LaplacianPE
    >>> from dgl.nn import LaplacianPosEnc

    >>> transform = LaplacianPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
    >>> g = transform(g)
    >>> EigVals, EigVecs = g.ndata['eigval'], g.ndata['eigvec']
    >>> TransformerLPE = LaplacianPosEnc(model_type="Transformer", num_layer=3, k=5,
    ...                                  lpe_dim=16, n_head=4)
    >>> PosEnc = TransformerLPE(EigVals, EigVecs)
    >>> DeepSetLPE = LaplacianPosEnc(model_type="DeepSet", num_layer=3, k=5,
    ...                              lpe_dim=16, num_post_layer=2)
    >>> PosEnc = DeepSetLPE(EigVals, EigVecs)
    """
    def __init__(
        self,
        model_type,
        num_layer,
        k,
        lpe_dim,
        n_head=1,
        batch_norm=False,
        num_post_layer=0,
    ):
        super(LaplacianPosEnc, self).__init__()

        self.model_type = model_type
        # Each node contributes k (eigenvector, eigenvalue) pairs; project
        # each 2-dimensional pair up to the encoder width.
        self.linear = nn.Linear(2, lpe_dim)

        if self.model_type == "Transformer":
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=lpe_dim, nhead=n_head, batch_first=True
            )
            self.pe_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layer
            )
        elif self.model_type == "DeepSet":
            layers = []
            if num_layer == 1:
                layers.append(nn.ReLU())
            else:
                self.linear = nn.Linear(2, 2 * lpe_dim)
                layers.append(nn.ReLU())
                for _ in range(num_layer - 2):
                    layers.append(nn.Linear(2 * lpe_dim, 2 * lpe_dim))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * lpe_dim, lpe_dim))
                layers.append(nn.ReLU())
            self.pe_encoder = nn.Sequential(*layers)
        else:
            raise ValueError(
                f"model_type '{model_type}' is not allowed, must be "
                "'Transformer' or 'DeepSet'."
            )

        if batch_norm:
            self.raw_norm = nn.BatchNorm1d(k)
        else:
            self.raw_norm = None

        if num_post_layer > 0:
            layers = []
            if num_post_layer == 1:
                layers.append(nn.Linear(lpe_dim, lpe_dim))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(lpe_dim, 2 * lpe_dim))
                layers.append(nn.ReLU())
                for _ in range(num_post_layer - 2):
                    layers.append(nn.Linear(2 * lpe_dim, 2 * lpe_dim))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * lpe_dim, lpe_dim))
                layers.append(nn.ReLU())
            self.post_mlp = nn.Sequential(*layers)
        else:
            self.post_mlp = None
    def forward(self, EigVals, EigVecs):
        r"""
        Parameters
        ----------
        EigVals : Tensor
            Laplacian eigenvalues of shape :math:`(N, k)`; the same :math:`k`
            eigenvalues are repeated for each of the :math:`N` nodes. They can
            be obtained by using `LaplacianPE`.
        EigVecs : Tensor
            Laplacian eigenvectors of shape :math:`(N, k)`, which can be
            obtained by using `LaplacianPE`.

        Returns
        -------
        Tensor
            The Laplacian positional encodings of shape :math:`(N, lpe_dim)`,
            where :math:`N` is the number of nodes in the input graph.
        """
        # Pair each eigenvector entry with its eigenvalue: shape (N, k, 2).
        PosEnc = th.cat(
            (EigVecs.unsqueeze(2), EigVals.unsqueeze(2)), dim=2
        ).float()
        empty_mask = th.isnan(PosEnc)
        PosEnc[empty_mask] = 0

        if self.raw_norm:
            PosEnc = self.raw_norm(PosEnc)
        PosEnc = self.linear(PosEnc)

        if self.model_type == "Transformer":
            PosEnc = self.pe_encoder(
                src=PosEnc, src_key_padding_mask=empty_mask[:, :, 1]
            )
        else:
            PosEnc = self.pe_encoder(PosEnc)

        # Remove masked sequences
        PosEnc[empty_mask[:, :, 1]] = 0

        # Sum pooling
        PosEnc = th.sum(PosEnc, 1, keepdim=False)

        # MLP post pooling
        if self.post_mlp:
            PosEnc = self.post_mlp(PosEnc)

        return PosEnc
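
A small self-contained check of the masking behaviour above (not part of the diff): forward zeroes NaN entries before encoding and again before sum pooling, so frequency slots marked as missing do not contribute to the pooled encoding. The tensors below are hand-built for illustration.

import torch as th
from dgl.nn import LaplacianPosEnc

num_nodes, k, lpe_dim = 4, 3, 8
eigvals = th.randn(num_nodes, k)
eigvecs = th.randn(num_nodes, k)

# Pretend the last frequency slot is padding by filling it with NaN;
# the encoder masks NaN positions internally.
eigvals[:, -1] = float("nan")
eigvecs[:, -1] = float("nan")

encoder = LaplacianPosEnc(
    model_type="DeepSet", num_layer=2, k=k, lpe_dim=lpe_dim, num_post_layer=1
)
pos_enc = encoder(eigvals, eigvecs)

assert pos_enc.shape == (num_nodes, lpe_dim)
assert not th.isnan(pos_enc).any()  # masked slots are zeroed, not propagated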
@@ -1731,3 +1731,24 @@ def test_MetaPath2Vec(idtype):
    model = model.to(dev)
    embeds = model.node_embed.weight
    assert embeds.shape[0] == g.num_nodes()

@pytest.mark.parametrize('num_layer', [1, 4])
@pytest.mark.parametrize('k', [3, 5])
@pytest.mark.parametrize('lpe_dim', [4, 16])
@pytest.mark.parametrize('n_head', [1, 4])
@pytest.mark.parametrize('batch_norm', [True, False])
@pytest.mark.parametrize('num_post_layer', [0, 1, 2])
def test_LaplacianPosEnc(num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer):
    ctx = F.ctx()
    num_nodes = 4
    EigVals = th.randn((num_nodes, k)).to(ctx)
    EigVecs = th.randn((num_nodes, k)).to(ctx)

    model = nn.LaplacianPosEnc("Transformer", num_layer, k, lpe_dim, n_head,
                               batch_norm, num_post_layer).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)

    model = nn.LaplacianPosEnc("DeepSet", num_layer, k, lpe_dim,
                               batch_norm=batch_norm, num_post_layer=num_post_layer).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)
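
Beyond the unit test, a hedged sketch of how such an encoding is commonly wired into a model, GraphGPS-style: concatenate the learned positional encoding with the raw node features and project to the model width before the main layers. The wrapper module, dimensions, and feature names are illustrative assumptions, not part of this commit.

import dgl
import torch as th
import torch.nn as nn
from dgl import LaplacianPE
from dgl.nn import LaplacianPosEnc


class NodeInputWithLPE(nn.Module):
    # Hypothetical helper: concatenate node features with the learned
    # Laplacian positional encoding, then project to the hidden size.
    def __init__(self, in_dim, hidden_dim, k=5, lpe_dim=8):
        super().__init__()
        self.lpe = LaplacianPosEnc(
            model_type="DeepSet", num_layer=2, k=k, lpe_dim=lpe_dim
        )
        self.proj = nn.Linear(in_dim + lpe_dim, hidden_dim)

    def forward(self, feat, eigvals, eigvecs):
        pos_enc = self.lpe(eigvals, eigvecs)                 # (N, lpe_dim)
        return self.proj(th.cat([feat, pos_enc], dim=-1))    # (N, hidden_dim)


# Toy usage on the docstring example graph.
g = dgl.graph(([0, 1, 2, 3, 4, 2, 3, 1, 4, 0], [2, 3, 1, 4, 0, 0, 1, 2, 3, 4]))
g = LaplacianPE(k=5, feat_name="eigvec", eigval_name="eigval", padding=True)(g)
feat = th.randn(g.num_nodes(), 12)
module = NodeInputWithLPE(in_dim=12, hidden_dim=32, k=5, lpe_dim=8)
h = module(feat, g.ndata["eigval"], g.ndata["eigvec"])
assert h.shape == (g.num_nodes(), 32)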