Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
d2ef2433
Unverified
Commit
d2ef2433
authored
Dec 02, 2021
by
Mufei Li
Committed by
GitHub
Dec 02, 2021
Browse files
[NN] EdgePredictor (#3518)
* Update * Update * Fix * Update * Update * update * Fix test * CI * CI
parent
490c5a8d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
209 additions
and
0 deletions
+209
-0
docs/source/api/python/nn.pytorch.rst
docs/source/api/python/nn.pytorch.rst
+14
-0
python/dgl/nn/pytorch/__init__.py
python/dgl/nn/pytorch/__init__.py
+1
-0
python/dgl/nn/pytorch/link.py
python/dgl/nn/pytorch/link.py
+175
-0
tests/pytorch/test_nn.py
tests/pytorch/test_nn.py
+19
-0
No files found.
docs/source/api/python/nn.pytorch.rst
View file @
d2ef2433
...
...
@@ -267,6 +267,20 @@ SetTransformerDecoder
:members:
:show-inheritance:
.. _apinn-pytorch-link
Predictor and Score Functions for Link Prediction
-------------------------------------------------
.. automodule:: dgl.nn.pytorch.link
EdgePredictor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.link.EdgePredictor
:members: forward, reset_parameters
:show-inheritance:
Heterogeneous Graph Convolution Module
----------------------------------------
...
...
python/dgl/nn/pytorch/__init__.py
View file @
d2ef2433
...
...
@@ -7,3 +7,4 @@ from .factory import *
from
.hetero
import
*
from
.utils
import
Sequential
,
WeightBasis
,
JumpingKnowledge
from
.sparse_emb
import
NodeEmbedding
from
.link
import
*
python/dgl/nn/pytorch/link.py
0 → 100644
View file @
d2ef2433
"""Torch modules for link prediction."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
__all__
=
[
'EdgePredictor'
]
class
EdgePredictor
(
nn
.
Module
):
r
"""
Description
-----------
Predictor/score function for pairs of node representations. Given a pair of node
representations, :math:`h_i` and :math:`h_j`, it combines them with
**dot product**
.. math::
h_i^{T} h_j
or **cosine similarity**
.. math::
\frac{h_i^{T} h_j}{{\| h_i \|}_2 \cdot {\| h_j \|}_2}
or **elementwise product**
.. math::
h_i \odot h_j
or **concatenation**
.. math::
h_i \Vert h_j
Optionally, it passes the combined results to a linear layer for the final prediction.
Parameters
----------
op : str
The operation to apply. It can be 'dot', 'cos', 'ele', or 'cat',
corresponding to the equations above in order.
in_feats : int, optional
The input feature size of :math:`h_i` and :math:`h_j`. It is required
only if a linear layer is to be applied.
out_feats : int, optional
The output feature size. It is reuiqred only if a linear layer is to be applied.
bias : bool, optional
Whether to use bias for the linear layer if it applies.
Examples
--------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import EdgePredictor
>>> num_nodes = 2
>>> num_edges = 3
>>> in_feats = 4
>>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)
>>> h = th.randn(num_nodes, in_feats)
>>> src, dst = g.edges()
>>> h_src = h[src]
>>> h_dst = h[dst]
Case1: dot product
>>> predictor = EdgePredictor('dot')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 1])
>>> predictor = EdgePredictor('dot', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
Case2: cosine similarity
>>> predictor = EdgePredictor('cos')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 1])
>>> predictor = EdgePredictor('cos', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
Case3: elementwise product
>>> predictor = EdgePredictor('ele')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 4])
>>> predictor = EdgePredictor('ele', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
Case4: concatenation
>>> predictor = EdgePredictor('cat')
>>> predictor(h_src, h_dst).shape
torch.Size([3, 8])
>>> predictor = EdgePredictor('cat', in_feats, out_feats=3)
>>> predictor.reset_parameters()
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
"""
def
__init__
(
self
,
op
,
in_feats
=
None
,
out_feats
=
None
,
bias
=
False
):
super
(
EdgePredictor
,
self
).
__init__
()
assert
op
in
[
'dot'
,
'cos'
,
'ele'
,
'cat'
],
\
"Expect op to be in ['dot', 'cos', 'ele', 'cat'], got {}"
.
format
(
op
)
self
.
op
=
op
if
(
in_feats
is
not
None
)
and
(
out_feats
is
not
None
):
if
op
in
[
'dot'
,
'cos'
]:
in_feats
=
1
elif
op
==
'cat'
:
in_feats
=
2
*
in_feats
self
.
linear
=
nn
.
Linear
(
in_feats
,
out_feats
,
bias
=
bias
)
else
:
self
.
linear
=
None
def
reset_parameters
(
self
):
r
"""
Description
-----------
Reinitialize learnable parameters.
"""
if
self
.
linear
is
not
None
:
self
.
linear
.
reset_parameters
()
def
forward
(
self
,
h_src
,
h_dst
):
r
"""
Description
-----------
Predict for pairs of node representations.
Parameters
----------
h_src : torch.Tensor
Source node features. The tensor is of shape :math:`(E, D_{in})`,
where :math:`E` is the number of edges/node pairs, and :math:`D_{in}`
is the input feature size.
h_dst : torch.Tensor
Destination node features. The tensor is of shape :math:`(E, D_{in})`,
where :math:`E` is the number of edges/node pairs, and :math:`D_{in}`
is the input feature size.
Returns
-------
torch.Tensor
The output features.
"""
if
self
.
op
==
'dot'
:
N
,
D
=
h_src
.
shape
h
=
torch
.
bmm
(
h_src
.
view
(
N
,
1
,
D
),
h_dst
.
view
(
N
,
D
,
1
)).
squeeze
(
-
1
)
elif
self
.
op
==
'cos'
:
h
=
F
.
cosine_similarity
(
h_src
,
h_dst
).
unsqueeze
(
-
1
)
elif
self
.
op
==
'ele'
:
h
=
h_src
*
h_dst
else
:
h
=
torch
.
cat
([
h_src
,
h_dst
],
dim
=-
1
)
if
self
.
linear
is
not
None
:
h
=
self
.
linear
(
h
)
return
h
tests/pytorch/test_nn.py
View file @
d2ef2433
...
...
@@ -1304,6 +1304,25 @@ def test_jumping_knowledge():
model
.
reset_parameters
()
assert
model
(
feat_list
).
shape
==
(
num_nodes
,
num_feats
)
@
pytest
.
mark
.
parametrize
(
'op'
,
[
'dot'
,
'cos'
,
'ele'
,
'cat'
])
def
test_edge_predictor
(
op
):
ctx
=
F
.
ctx
()
num_pairs
=
3
in_feats
=
4
out_feats
=
5
h_src
=
th
.
randn
((
num_pairs
,
in_feats
)).
to
(
ctx
)
h_dst
=
th
.
randn
((
num_pairs
,
in_feats
)).
to
(
ctx
)
pred
=
nn
.
EdgePredictor
(
op
)
if
op
in
[
'dot'
,
'cos'
]:
assert
pred
(
h_src
,
h_dst
).
shape
==
(
num_pairs
,
1
)
elif
op
==
'ele'
:
assert
pred
(
h_src
,
h_dst
).
shape
==
(
num_pairs
,
in_feats
)
else
:
assert
pred
(
h_src
,
h_dst
).
shape
==
(
num_pairs
,
2
*
in_feats
)
pred
=
nn
.
EdgePredictor
(
op
,
in_feats
,
out_feats
,
bias
=
True
).
to
(
ctx
)
assert
pred
(
h_src
,
h_dst
).
shape
==
(
num_pairs
,
out_feats
)
if
__name__
==
'__main__'
:
test_graph_conv
()
test_graph_conv_e_weight
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment