"docs/source/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "94d3c6cb9f81d232cce70d4a27a9cb4e59d18534"
Unverified Commit d2ef2433 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[NN] EdgePredictor (#3518)

* Update

* Update

* Fix

* Update

* Update

* update

* Fix test

* CI

* CI
parent 490c5a8d
...@@ -267,6 +267,20 @@ SetTransformerDecoder ...@@ -267,6 +267,20 @@ SetTransformerDecoder
:members: :members:
:show-inheritance: :show-inheritance:
.. _apinn-pytorch-link
Predictor and Score Functions for Link Prediction
-------------------------------------------------
.. automodule:: dgl.nn.pytorch.link
EdgePredictor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.link.EdgePredictor
:members: forward, reset_parameters
:show-inheritance:
Heterogeneous Graph Convolution Module Heterogeneous Graph Convolution Module
---------------------------------------- ----------------------------------------
......
...@@ -7,3 +7,4 @@ from .factory import * ...@@ -7,3 +7,4 @@ from .factory import *
from .hetero import * from .hetero import *
from .utils import Sequential, WeightBasis, JumpingKnowledge from .utils import Sequential, WeightBasis, JumpingKnowledge
from .sparse_emb import NodeEmbedding from .sparse_emb import NodeEmbedding
from .link import *
"""Torch modules for link prediction."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['EdgePredictor']
class EdgePredictor(nn.Module):
r"""
Description
-----------
Predictor/score function for pairs of node representations. Given a pair of node
representations, :math:`h_i` and :math:`h_j`, it combines them with
**dot product**
.. math::
h_i^{T} h_j
or **cosine similarity**
.. math::
\frac{h_i^{T} h_j}{{\| h_i \|}_2 \cdot {\| h_j \|}_2}
or **elementwise product**
.. math::
h_i \odot h_j
or **concatenation**
.. math::
h_i \Vert h_j
Optionally, it passes the combined results to a linear layer for the final prediction.
Parameters
----------
op : str
The operation to apply. It can be 'dot', 'cos', 'ele', or 'cat',
corresponding to the equations above in order.
in_feats : int, optional
The input feature size of :math:`h_i` and :math:`h_j`. It is required
only if a linear layer is to be applied.
out_feats : int, optional
The output feature size. It is reuiqred only if a linear layer is to be applied.
bias : bool, optional
Whether to use bias for the linear layer if it applies.
Examples
--------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import EdgePredictor
>>> num_nodes = 2
>>> num_edges = 3
>>> in_feats = 4
>>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)
>>> h = th.randn(num_nodes, in_feats)
>>> src, dst = g.edges()
>>> h_src = h[src]
>>> h_dst = h[dst]
Case1: dot product
>>> predictor = EdgePredictor('dot')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 1])
>>> predictor = EdgePredictor('dot', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
Case2: cosine similarity
>>> predictor = EdgePredictor('cos')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 1])
>>> predictor = EdgePredictor('cos', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
Case3: elementwise product
>>> predictor = EdgePredictor('ele')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 4])
>>> predictor = EdgePredictor('ele', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
Case4: concatenation
>>> predictor = EdgePredictor('cat')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 8])
>>> predictor = EdgePredictor('cat', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
"""
def __init__(self,
op,
in_feats=None,
out_feats=None,
bias=False):
super(EdgePredictor, self).__init__()
assert op in ['dot', 'cos', 'ele', 'cat'], \
"Expect op to be in ['dot', 'cos', 'ele', 'cat'], got {}".format(op)
self.op = op
if (in_feats is not None) and (out_feats is not None):
if op in ['dot', 'cos']:
in_feats = 1
elif op == 'cat':
in_feats = 2 * in_feats
self.linear = nn.Linear(in_feats, out_feats, bias=bias)
else:
self.linear = None
def reset_parameters(self):
r"""
Description
-----------
Reinitialize learnable parameters.
"""
if self.linear is not None:
self.linear.reset_parameters()
def forward(self, h_src, h_dst):
r"""
Description
-----------
Predict for pairs of node representations.
Parameters
----------
h_src : torch.Tensor
Source node features. The tensor is of shape :math:`(E, D_{in})`,
where :math:`E` is the number of edges/node pairs, and :math:`D_{in}`
is the input feature size.
h_dst : torch.Tensor
Destination node features. The tensor is of shape :math:`(E, D_{in})`,
where :math:`E` is the number of edges/node pairs, and :math:`D_{in}`
is the input feature size.
Returns
-------
torch.Tensor
The output features.
"""
if self.op == 'dot':
N, D = h_src.shape
h = torch.bmm(h_src.view(N, 1, D), h_dst.view(N, D, 1)).squeeze(-1)
elif self.op == 'cos':
h = F.cosine_similarity(h_src, h_dst).unsqueeze(-1)
elif self.op == 'ele':
h = h_src * h_dst
else:
h = torch.cat([h_src, h_dst], dim=-1)
if self.linear is not None:
h = self.linear(h)
return h
...@@ -1304,6 +1304,25 @@ def test_jumping_knowledge(): ...@@ -1304,6 +1304,25 @@ def test_jumping_knowledge():
model.reset_parameters() model.reset_parameters()
assert model(feat_list).shape == (num_nodes, num_feats) assert model(feat_list).shape == (num_nodes, num_feats)
@pytest.mark.parametrize('op', ['dot', 'cos', 'ele', 'cat'])
def test_edge_predictor(op):
ctx = F.ctx()
num_pairs = 3
in_feats = 4
out_feats = 5
h_src = th.randn((num_pairs, in_feats)).to(ctx)
h_dst = th.randn((num_pairs, in_feats)).to(ctx)
pred = nn.EdgePredictor(op)
if op in ['dot', 'cos']:
assert pred(h_src, h_dst).shape == (num_pairs, 1)
elif op == 'ele':
assert pred(h_src, h_dst).shape == (num_pairs, in_feats)
else:
assert pred(h_src, h_dst).shape == (num_pairs, 2 * in_feats)
pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
assert pred(h_src, h_dst).shape == (num_pairs, out_feats)
if __name__ == '__main__': if __name__ == '__main__':
test_graph_conv() test_graph_conv()
test_graph_conv_e_weight() test_graph_conv_e_weight()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment