OpenDAS / dgl · Commits

Commit aa884d43 (unverified)
Authored Jan 25, 2021 by Zihao Ye; committed by GitHub on Jan 25, 2021
Parent: a6abffe3

[doc][fix] Improve the docstring and fix its behavior in DGL's kernel (#2563)

* upd
* fix
* lint
* fix
* upd
Showing 3 changed files with 41 additions and 4 deletions (+41 −4)
* python/dgl/backend/pytorch/sparse.py (+32 −4)
* python/dgl/heterograph.py (+6 −0)
* python/dgl/transform.py (+3 −0)
python/dgl/backend/pytorch/sparse.py
```diff
@@ -5,6 +5,34 @@ from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
 __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
 
+
+_inverse_format = {
+    'coo': 'coo',
+    'csr': 'csc',
+    'csc': 'csr'
+}
+
+
+def _reverse(gidx):
+    """Reverse the given graph index while retaining its formats.
+
+    Parameters
+    ----------
+    gidx: HeteroGraphIndex
+
+    Return
+    ------
+    HeteroGraphIndex
+    """
+    g_rev = gidx.reverse()
+    original_formats_dict = gidx.formats()
+    original_formats = original_formats_dict['created'] + \
+        original_formats_dict['not created']
+    g_rev = g_rev.formats([_inverse_format[fmt] for fmt in original_formats])
+    return g_rev
+
+
 def _reduce_grad(grad, shape):
     """Reduce gradient on the broadcast dimension
 
     If there is broadcast in forward pass, gradients need to be reduced on
```
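The format-inversion table encodes a simple identity: reversing a graph transposes its adjacency matrix, and the transpose of a CSR matrix is a CSC matrix (and vice versa), while COO is unchanged under transposition. That is why `_reverse` can keep the reversed index in a kernel-friendly format without re-materializing anything. A minimal illustration with `scipy.sparse` (a toy 3-node graph, not part of the commit):

```python
# Illustration only: the transpose of a CSR adjacency matrix is already a CSC
# matrix, and a COO matrix transposes to COO -- exactly the mapping in
# _inverse_format ('csr' <-> 'csc', 'coo' -> 'coo').
import numpy as np
import scipy.sparse as sp

# Toy directed graph with edges 0->1, 0->2, 2->1, stored in CSR.
adj = sp.csr_matrix((np.ones(3), ([0, 0, 2], [1, 2, 1])), shape=(3, 3))
rev = adj.T  # adjacency matrix of the reversed graph

print(adj.format, rev.format)   # csr csc -- no format conversion needed
print(adj.tocoo().T.format)     # coo     -- COO is symmetric under transpose
```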
@@ -71,7 +99,7 @@ class GSpMM(th.autograd.Function):
...
@@ -71,7 +99,7 @@ class GSpMM(th.autograd.Function):
gidx
,
op
,
reduce_op
=
ctx
.
backward_cache
gidx
,
op
,
reduce_op
=
ctx
.
backward_cache
X
,
Y
,
argX
,
argY
=
ctx
.
saved_tensors
X
,
Y
,
argX
,
argY
=
ctx
.
saved_tensors
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
3
]:
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
3
]:
g_rev
=
gidx
.
reverse
()
g_rev
=
_
reverse
(
gidx
)
if
reduce_op
==
'sum'
:
if
reduce_op
==
'sum'
:
if
op
in
[
'mul'
,
'div'
]:
if
op
in
[
'mul'
,
'div'
]:
dX
=
gspmm
(
g_rev
,
'mul'
,
'sum'
,
dZ
,
_muldiv
(
op
,
Y
))
dX
=
gspmm
(
g_rev
,
'mul'
,
'sum'
,
dZ
,
_muldiv
(
op
,
Y
))
...
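The first call site is `GSpMM.backward`: the gradient of a sum-aggregation over incoming edges is itself an SpMM on the reversed graph. A minimal sketch (not part of the commit) that exercises this backward path through DGL's public message-passing API, using a small toy graph:

```python
# Minimal sketch: autograd-enabled message passing in the PyTorch backend is
# dispatched to GSpMM, so a plain update_all followed by backward() runs the
# reversed-graph SpMM that now goes through _reverse(gidx).
import torch
import dgl
import dgl.function as fn

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
x = torch.randn(3, 4, requires_grad=True)
g.ndata['h'] = x

# Forward: sum incoming source features (copy_u / sum -> an SpMM kernel).
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
g.ndata['h_sum'].sum().backward()

# Each source node's gradient is the sum of dZ over its outgoing edges,
# i.e. a copy/sum SpMM on the reversed graph.
print(x.grad.shape)  # torch.Size([3, 4])
```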
@@ -132,7 +160,7 @@ class GSDDMM(th.autograd.Function):
...
@@ -132,7 +160,7 @@ class GSDDMM(th.autograd.Function):
X
,
Y
=
ctx
.
saved_tensors
X
,
Y
=
ctx
.
saved_tensors
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
2
]:
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
2
]:
if
lhs_target
in
[
'u'
,
'v'
]:
if
lhs_target
in
[
'u'
,
'v'
]:
_gidx
=
gidx
if
lhs_target
==
'v'
else
gidx
.
reverse
()
_gidx
=
gidx
if
lhs_target
==
'v'
else
_
reverse
(
gidx
)
if
op
in
[
'add'
,
'sub'
,
'copy_lhs'
]:
if
op
in
[
'add'
,
'sub'
,
'copy_lhs'
]:
dX
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
dX
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
else
:
# mul, div, dot
else
:
# mul, div, dot
...
@@ -152,7 +180,7 @@ class GSDDMM(th.autograd.Function):
...
@@ -152,7 +180,7 @@ class GSDDMM(th.autograd.Function):
dX
=
None
dX
=
None
if
op
!=
'copy_lhs'
and
ctx
.
needs_input_grad
[
3
]:
if
op
!=
'copy_lhs'
and
ctx
.
needs_input_grad
[
3
]:
if
rhs_target
in
[
'u'
,
'v'
]:
if
rhs_target
in
[
'u'
,
'v'
]:
_gidx
=
gidx
if
rhs_target
==
'v'
else
gidx
.
reverse
()
_gidx
=
gidx
if
rhs_target
==
'v'
else
_
reverse
(
gidx
)
if
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
if
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
dY
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
_addsub
(
op
,
dZ
))
dY
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
_addsub
(
op
,
dZ
))
else
:
# mul, div, dot
else
:
# mul, div, dot
...
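The same substitution is made in `GSDDMM.backward` for both operands: when an edge-wise result is differentiated with respect to source-node (`'u'`) features, the per-edge gradients are accumulated back to the sources via an SpMM on the reversed graph. A minimal sketch (not part of the commit) that reaches this path via `apply_edges`:

```python
# Minimal sketch: edge-wise operators built with apply_edges dispatch to
# GSDDMM, so differentiating an edge score w.r.t. node features exercises
# the backward branches patched above.
import torch
import dgl
import dgl.function as fn

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
h = torch.randn(3, 4, requires_grad=True)
g.ndata['h'] = h

# Forward: per-edge dot product of source and destination features (an SDDMM kernel).
g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
g.edata['score'].sum().backward()

# Gradients w.r.t. the 'u' side are accumulated by an SpMM on the reversed
# graph, which is exactly where _reverse(gidx) is now used.
print(h.grad.shape)  # torch.Size([3, 4])
```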
@@ -198,7 +226,7 @@ class EdgeSoftmax(th.autograd.Function):
...
@@ -198,7 +226,7 @@ class EdgeSoftmax(th.autograd.Function):
if
not
is_all
(
eids
):
if
not
is_all
(
eids
):
gidx
=
gidx
.
edge_subgraph
([
eids
],
True
).
graph
gidx
=
gidx
.
edge_subgraph
([
eids
],
True
).
graph
if
norm_by
==
'src'
:
if
norm_by
==
'src'
:
gidx
=
gidx
.
reverse
()
gidx
=
_
reverse
(
gidx
)
score_max
=
_gspmm
(
gidx
,
'copy_rhs'
,
'max'
,
None
,
score
)[
0
]
score_max
=
_gspmm
(
gidx
,
'copy_rhs'
,
'max'
,
None
,
score
)[
0
]
score
=
th
.
exp
(
_gsddmm
(
gidx
,
'sub'
,
score
,
score_max
,
'e'
,
'v'
))
score
=
th
.
exp
(
_gsddmm
(
gidx
,
'sub'
,
score
,
score_max
,
'e'
,
'v'
))
score_sum
=
_gspmm
(
gidx
,
'copy_rhs'
,
'sum'
,
None
,
score
)[
0
]
score_sum
=
_gspmm
(
gidx
,
'copy_rhs'
,
'sum'
,
None
,
score
)[
0
]
...
...
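`EdgeSoftmax` implements the usual numerically stable softmax over each destination node's incoming edge scores (max-reduce, shifted exp, sum-reduce, normalize); with `norm_by == 'src'` the same computation runs on the reversed graph, which is where `_reverse` now comes in. The sketch below (not part of the commit) re-derives the same quantity with DGL's message-passing builtins:

```python
# Minimal sketch of numerically stable edge softmax, mirroring the three
# kernel steps above: per-destination max, shifted exponential, per-destination
# sum, then per-edge normalization.
import torch
import dgl
import dgl.function as fn

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.edata['s'] = torch.randn(g.num_edges(), 1)

# 1. Per-destination max of incoming edge scores.
g.update_all(fn.copy_e('s', 'm'), fn.max('m', 's_max'))
# 2. Subtract the max and exponentiate (numerical stability).
g.apply_edges(fn.e_sub_v('s', 's_max', 's_shift'))
g.edata['exp'] = torch.exp(g.edata['s_shift'])
# 3. Per-destination sum of shifted exponentials, then normalize each edge.
g.update_all(fn.copy_e('exp', 'm'), fn.sum('m', 'z'))
g.apply_edges(fn.e_div_v('exp', 'z', 'softmax'))

print(g.edata['softmax'])  # each destination's incoming scores sum to 1
```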
python/dgl/heterograph.py
```diff
@@ -1291,13 +1291,16 @@ class DGLHeteroGraph(object):
         >>> import torch
 
         Create a homogeneous graph.
 
         >>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
 
+        Manually set batch information
+        >>> g.set_batch_num_nodes(torch.tensor([3, 3]))
+        >>> g.set_batch_num_edges(torch.tensor([3, 3]))
 
         Unbatch the graph.
 
         >>> dgl.unbatch(g)
         [Graph(num_nodes=3, num_edges=3,
                ndata_schemes={}
```
```diff
@@ -1433,13 +1436,16 @@ class DGLHeteroGraph(object):
         >>> import torch
 
         Create a homogeneous graph.
 
         >>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
 
+        Manually set batch information
+        >>> g.set_batch_num_nodes(torch.tensor([3, 3]))
+        >>> g.set_batch_num_edges(torch.tensor([3, 3]))
 
         Unbatch the graph.
 
         >>> dgl.unbatch(g)
         [Graph(num_nodes=3, num_edges=3,
                ndata_schemes={}
```
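Both hunks add the same example to the batching docstrings. Written out as a standalone script (illustration only, not part of the commit), the added doctest behaves as follows:

```python
# Standalone version of the added doctest: without the manual batch
# information, dgl.unbatch has nothing to split on; after setting it, the
# six-node graph is treated as a batch of two three-node components.
import torch
import dgl

g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
g.set_batch_num_nodes(torch.tensor([3, 3]))
g.set_batch_num_edges(torch.tensor([3, 3]))

parts = dgl.unbatch(g)
print(len(parts))            # 2
print(parts[0].num_nodes())  # 3
print(parts[1].num_edges())  # 3
```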
python/dgl/transform.py
```diff
@@ -644,6 +644,9 @@ def reverse(g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_eda
     :math:`(i_1, j_1), (i_2, j_2), \cdots` of type ``(U, E, V)`` is a new graph with edges
     :math:`(j_1, i_1), (j_2, i_2), \cdots` of type ``(V, E, U)``.
 
+    The returned graph shares the data structure with the original graph, i.e. dgl.reverse
+    will not create extra storage for the reversed graph.
+
     Parameters
     ----------
     g : DGLGraph
```
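The added sentence documents that the reversed graph shares storage with the input rather than copying it. A minimal sketch of the documented defaults (`copy_ndata=True`, `copy_edata=False`); illustration only, not part of the commit:

```python
# Minimal sketch: dgl.reverse flips edge directions without materializing a
# new graph structure; node features are carried over by default
# (copy_ndata=True) while edge features are not (copy_edata=False).
import torch
import dgl

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata['h'] = torch.arange(3).float()
g.edata['w'] = torch.ones(3)

rg = dgl.reverse(g)
print(rg.edges())              # src and dst swapped relative to g.edges()
print(list(rg.ndata.keys()))   # ['h'] -- node data carried over
print(list(rg.edata.keys()))   # []    -- edge data dropped by default
```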