Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
013d1456
Unverified
Commit
013d1456
authored
Dec 07, 2020
by
Mufei Li
Committed by
GitHub
Dec 07, 2020
Browse files
[NN] Attention Retrieval for NN Modules (#2397)
* Update * Update
parent
c038b71f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
55 additions
and
8 deletions
+55
-8
python/dgl/nn/mxnet/conv/gatconv.py
python/dgl/nn/mxnet/conv/gatconv.py
+11
-2
python/dgl/nn/pytorch/conv/dotgatconv.py
python/dgl/nn/pytorch/conv/dotgatconv.py
+10
-2
python/dgl/nn/pytorch/conv/gatconv.py
python/dgl/nn/pytorch/conv/gatconv.py
+11
-2
python/dgl/nn/tensorflow/conv/gatconv.py
python/dgl/nn/tensorflow/conv/gatconv.py
+11
-2
tests/mxnet/test_nn.py
tests/mxnet/test_nn.py
+4
-0
tests/pytorch/test_nn.py
tests/pytorch/test_nn.py
+4
-0
tests/tensorflow/test_nn.py
tests/tensorflow/test_nn.py
+4
-0
No files found.
python/dgl/nn/mxnet/conv/gatconv.py
View file @
013d1456
...
@@ -201,7 +201,7 @@ class GATConv(nn.Block):
...
@@ -201,7 +201,7 @@ class GATConv(nn.Block):
"""
"""
self
.
_allow_zero_in_degree
=
set_value
self
.
_allow_zero_in_degree
=
set_value
def
forward
(
self
,
graph
,
feat
):
def
forward
(
self
,
graph
,
feat
,
get_attention
=
False
):
r
"""
r
"""
Description
Description
...
@@ -217,12 +217,17 @@ class GATConv(nn.Block):
...
@@ -217,12 +217,17 @@ class GATConv(nn.Block):
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape
If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
get_attention : bool, optional
Whether to return the attention values. Default to False.
Returns
Returns
-------
-------
mxnet.NDArray
mxnet.NDArray
The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
is the number of heads, and :math:`D_{out}` is size of output feature.
is the number of heads, and :math:`D_{out}` is size of output feature.
mxnet.NDArray, optional
The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
edges. This is returned only when :attr:`get_attention` is ``True``.
Raises
Raises
------
------
...
@@ -288,4 +293,8 @@ class GATConv(nn.Block):
...
@@ -288,4 +293,8 @@ class GATConv(nn.Block):
# activation
# activation
if
self
.
activation
:
if
self
.
activation
:
rst
=
self
.
activation
(
rst
)
rst
=
self
.
activation
(
rst
)
return
rst
if
get_attention
:
return
rst
,
graph
.
edata
[
'a'
]
else
:
return
rst
python/dgl/nn/pytorch/conv/dotgatconv.py
View file @
013d1456
...
@@ -114,7 +114,7 @@ class DotGatConv(nn.Module):
...
@@ -114,7 +114,7 @@ class DotGatConv(nn.Module):
else
:
else
:
self
.
fc
=
nn
.
Linear
(
self
.
_in_src_feats
,
self
.
_out_feats
,
bias
=
False
)
self
.
fc
=
nn
.
Linear
(
self
.
_in_src_feats
,
self
.
_out_feats
,
bias
=
False
)
def
forward
(
self
,
graph
,
feat
):
def
forward
(
self
,
graph
,
feat
,
get_attention
=
False
):
r
"""
r
"""
Description
Description
...
@@ -130,12 +130,17 @@ class DotGatConv(nn.Module):
...
@@ -130,12 +130,17 @@ class DotGatConv(nn.Module):
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
get_attention : bool, optional
Whether to return the attention values. Default to False.
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` is size
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` is size
of output feature.
of output feature.
torch.Tensor, optional
The attention values of shape :math:`(E, 1)`, where :math:`E` is the number of
edges. This is returned only when :attr:`get_attention` is ``True``.
Raises
Raises
------
------
...
@@ -187,4 +192,7 @@ class DotGatConv(nn.Module):
...
@@ -187,4 +192,7 @@ class DotGatConv(nn.Module):
# output results to the destination nodes
# output results to the destination nodes
rst
=
graph
.
dstdata
[
'agg_u'
]
rst
=
graph
.
dstdata
[
'agg_u'
]
return
rst
if
get_attention
:
return
rst
,
graph
.
edata
[
'sa'
]
else
:
return
rst
python/dgl/nn/pytorch/conv/gatconv.py
View file @
013d1456
...
@@ -208,7 +208,7 @@ class GATConv(nn.Module):
...
@@ -208,7 +208,7 @@ class GATConv(nn.Module):
"""
"""
self
.
_allow_zero_in_degree
=
set_value
self
.
_allow_zero_in_degree
=
set_value
def
forward
(
self
,
graph
,
feat
):
def
forward
(
self
,
graph
,
feat
,
get_attention
=
False
):
r
"""
r
"""
Description
Description
...
@@ -224,12 +224,17 @@ class GATConv(nn.Module):
...
@@ -224,12 +224,17 @@ class GATConv(nn.Module):
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
get_attention : bool, optional
Whether to return the attention values. Default to False.
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
is the number of heads, and :math:`D_{out}` is size of output feature.
is the number of heads, and :math:`D_{out}` is size of output feature.
torch.Tensor, optional
The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
edges. This is returned only when :attr:`get_attention` is ``True``.
Raises
Raises
------
------
...
@@ -294,4 +299,8 @@ class GATConv(nn.Module):
...
@@ -294,4 +299,8 @@ class GATConv(nn.Module):
# activation
# activation
if
self
.
activation
:
if
self
.
activation
:
rst
=
self
.
activation
(
rst
)
rst
=
self
.
activation
(
rst
)
return
rst
if
get_attention
:
return
rst
,
graph
.
edata
[
'a'
]
else
:
return
rst
python/dgl/nn/tensorflow/conv/gatconv.py
View file @
013d1456
...
@@ -195,7 +195,7 @@ class GATConv(layers.Layer):
...
@@ -195,7 +195,7 @@ class GATConv(layers.Layer):
"""
"""
self
.
_allow_zero_in_degree
=
set_value
self
.
_allow_zero_in_degree
=
set_value
def
call
(
self
,
graph
,
feat
):
def
call
(
self
,
graph
,
feat
,
get_attention
=
False
):
r
"""
r
"""
Description
Description
...
@@ -211,12 +211,17 @@ class GATConv(layers.Layer):
...
@@ -211,12 +211,17 @@ class GATConv(layers.Layer):
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of tf.Tensor is given, the pair must contain two tensors of shape
If a pair of tf.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
get_attention : bool, optional
Whether to return the attention values. Default to False.
Returns
Returns
-------
-------
tf.Tensor
tf.Tensor
The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
is the number of heads, and :math:`D_{out}` is size of output feature.
is the number of heads, and :math:`D_{out}` is size of output feature.
tf.Tensor, optional
The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
edges. This is returned only when :attr:`get_attention` is ``True``.
Raises
Raises
------
------
...
@@ -282,4 +287,8 @@ class GATConv(layers.Layer):
...
@@ -282,4 +287,8 @@ class GATConv(layers.Layer):
# activation
# activation
if
self
.
activation
:
if
self
.
activation
:
rst
=
self
.
activation
(
rst
)
rst
=
self
.
activation
(
rst
)
return
rst
if
get_attention
:
return
rst
,
graph
.
edata
[
'a'
]
else
:
return
rst
tests/mxnet/test_nn.py
View file @
013d1456
...
@@ -167,6 +167,8 @@ def test_gat_conv(g, idtype):
...
@@ -167,6 +167,8 @@ def test_gat_conv(g, idtype):
feat
=
F
.
randn
((
g
.
number_of_nodes
(),
10
))
feat
=
F
.
randn
((
g
.
number_of_nodes
(),
10
))
h
=
gat
(
g
,
feat
)
h
=
gat
(
g
,
feat
)
assert
h
.
shape
==
(
g
.
number_of_nodes
(),
5
,
20
)
assert
h
.
shape
==
(
g
.
number_of_nodes
(),
5
,
20
)
_
,
a
=
gat
(
g
,
feat
,
True
)
assert
a
.
shape
==
(
g
.
number_of_edges
(),
5
,
1
)
@
parametrize_dtype
@
parametrize_dtype
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'bipartite'
],
exclude
=
[
'zero-degree'
]))
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'bipartite'
],
exclude
=
[
'zero-degree'
]))
...
@@ -178,6 +180,8 @@ def test_gat_conv_bi(g, idtype):
...
@@ -178,6 +180,8 @@ def test_gat_conv_bi(g, idtype):
feat
=
(
F
.
randn
((
g
.
number_of_src_nodes
(),
5
)),
F
.
randn
((
g
.
number_of_dst_nodes
(),
5
)))
feat
=
(
F
.
randn
((
g
.
number_of_src_nodes
(),
5
)),
F
.
randn
((
g
.
number_of_dst_nodes
(),
5
)))
h
=
gat
(
g
,
feat
)
h
=
gat
(
g
,
feat
)
assert
h
.
shape
==
(
g
.
number_of_dst_nodes
(),
4
,
2
)
assert
h
.
shape
==
(
g
.
number_of_dst_nodes
(),
4
,
2
)
_
,
a
=
gat
(
g
,
feat
,
True
)
assert
a
.
shape
==
(
g
.
number_of_edges
(),
4
,
1
)
@
parametrize_dtype
@
parametrize_dtype
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'homo'
,
'block-bipartite'
]))
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'homo'
,
'block-bipartite'
]))
...
...
tests/pytorch/test_nn.py
View file @
013d1456
...
@@ -379,6 +379,8 @@ def test_gat_conv(g, idtype):
...
@@ -379,6 +379,8 @@ def test_gat_conv(g, idtype):
gat
=
gat
.
to
(
ctx
)
gat
=
gat
.
to
(
ctx
)
h
=
gat
(
g
,
feat
)
h
=
gat
(
g
,
feat
)
assert
h
.
shape
==
(
g
.
number_of_nodes
(),
4
,
2
)
assert
h
.
shape
==
(
g
.
number_of_nodes
(),
4
,
2
)
_
,
a
=
gat
(
g
,
feat
,
get_attention
=
True
)
assert
a
.
shape
==
(
g
.
number_of_edges
(),
4
,
1
)
@
parametrize_dtype
@
parametrize_dtype
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'bipartite'
],
exclude
=
[
'zero-degree'
]))
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'bipartite'
],
exclude
=
[
'zero-degree'
]))
...
@@ -390,6 +392,8 @@ def test_gat_conv_bi(g, idtype):
...
@@ -390,6 +392,8 @@ def test_gat_conv_bi(g, idtype):
gat
=
gat
.
to
(
ctx
)
gat
=
gat
.
to
(
ctx
)
h
=
gat
(
g
,
feat
)
h
=
gat
(
g
,
feat
)
assert
h
.
shape
==
(
g
.
number_of_dst_nodes
(),
4
,
2
)
assert
h
.
shape
==
(
g
.
number_of_dst_nodes
(),
4
,
2
)
_
,
a
=
gat
(
g
,
feat
,
get_attention
=
True
)
assert
a
.
shape
==
(
g
.
number_of_edges
(),
4
,
1
)
@
parametrize_dtype
@
parametrize_dtype
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'homo'
,
'block-bipartite'
]))
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'homo'
,
'block-bipartite'
]))
...
...
tests/tensorflow/test_nn.py
View file @
013d1456
...
@@ -266,6 +266,8 @@ def test_gat_conv(g, idtype):
...
@@ -266,6 +266,8 @@ def test_gat_conv(g, idtype):
feat
=
F
.
randn
((
g
.
number_of_nodes
(),
5
))
feat
=
F
.
randn
((
g
.
number_of_nodes
(),
5
))
h
=
gat
(
g
,
feat
)
h
=
gat
(
g
,
feat
)
assert
h
.
shape
==
(
g
.
number_of_nodes
(),
4
,
2
)
assert
h
.
shape
==
(
g
.
number_of_nodes
(),
4
,
2
)
_
,
a
=
gat
(
g
,
feat
,
get_attention
=
True
)
assert
a
.
shape
==
(
g
.
number_of_edges
(),
4
,
1
)
@
parametrize_dtype
@
parametrize_dtype
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'bipartite'
],
exclude
=
[
'zero-degree'
]))
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'bipartite'
],
exclude
=
[
'zero-degree'
]))
...
@@ -276,6 +278,8 @@ def test_gat_conv_bi(g, idtype):
...
@@ -276,6 +278,8 @@ def test_gat_conv_bi(g, idtype):
feat
=
(
F
.
randn
((
g
.
number_of_src_nodes
(),
5
)),
F
.
randn
((
g
.
number_of_dst_nodes
(),
5
)))
feat
=
(
F
.
randn
((
g
.
number_of_src_nodes
(),
5
)),
F
.
randn
((
g
.
number_of_dst_nodes
(),
5
)))
h
=
gat
(
g
,
feat
)
h
=
gat
(
g
,
feat
)
assert
h
.
shape
==
(
g
.
number_of_dst_nodes
(),
4
,
2
)
assert
h
.
shape
==
(
g
.
number_of_dst_nodes
(),
4
,
2
)
_
,
a
=
gat
(
g
,
feat
,
get_attention
=
True
)
assert
a
.
shape
==
(
g
.
number_of_edges
(),
4
,
1
)
@
parametrize_dtype
@
parametrize_dtype
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'homo'
,
'block-bipartite'
]))
@
pytest
.
mark
.
parametrize
(
'g'
,
get_cases
([
'homo'
,
'block-bipartite'
]))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment