Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
aa884d43
Unverified
Commit
aa884d43
authored
Jan 25, 2021
by
Zihao Ye
Committed by
GitHub
Jan 25, 2021
Browse files
[doc][fix] Improve the docstring and fix its behavior in DGL's kernel (#2563)
* upd * fix * lint * fix * upd
parent
a6abffe3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
4 deletions
+41
-4
python/dgl/backend/pytorch/sparse.py
python/dgl/backend/pytorch/sparse.py
+32
-4
python/dgl/heterograph.py
python/dgl/heterograph.py
+6
-0
python/dgl/transform.py
python/dgl/transform.py
+3
-0
No files found.
python/dgl/backend/pytorch/sparse.py
View file @
aa884d43
...
...
@@ -5,6 +5,34 @@ from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
__all__
=
[
'gspmm'
,
'gsddmm'
,
'edge_softmax'
,
'segment_reduce'
]
_inverse_format
=
{
'coo'
:
'coo'
,
'csr'
:
'csc'
,
'csc'
:
'csr'
}
def
_reverse
(
gidx
):
"""Reverse the given graph index while retaining its formats.
Parameters
----------
gidx: HeteroGraphIndex
Return
------
HeteroGraphIndex
"""
g_rev
=
gidx
.
reverse
()
original_formats_dict
=
gidx
.
formats
()
original_formats
=
original_formats_dict
[
'created'
]
+
\
original_formats_dict
[
'not created'
]
g_rev
=
g_rev
.
formats
([
_inverse_format
[
fmt
]
for
fmt
in
original_formats
])
return
g_rev
def
_reduce_grad
(
grad
,
shape
):
"""Reduce gradient on the broadcast dimension
If there is broadcast in forward pass, gradients need to be reduced on
...
...
@@ -71,7 +99,7 @@ class GSpMM(th.autograd.Function):
gidx
,
op
,
reduce_op
=
ctx
.
backward_cache
X
,
Y
,
argX
,
argY
=
ctx
.
saved_tensors
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
3
]:
g_rev
=
gidx
.
reverse
()
g_rev
=
_
reverse
(
gidx
)
if
reduce_op
==
'sum'
:
if
op
in
[
'mul'
,
'div'
]:
dX
=
gspmm
(
g_rev
,
'mul'
,
'sum'
,
dZ
,
_muldiv
(
op
,
Y
))
...
...
@@ -132,7 +160,7 @@ class GSDDMM(th.autograd.Function):
X
,
Y
=
ctx
.
saved_tensors
if
op
!=
'copy_rhs'
and
ctx
.
needs_input_grad
[
2
]:
if
lhs_target
in
[
'u'
,
'v'
]:
_gidx
=
gidx
if
lhs_target
==
'v'
else
gidx
.
reverse
()
_gidx
=
gidx
if
lhs_target
==
'v'
else
_
reverse
(
gidx
)
if
op
in
[
'add'
,
'sub'
,
'copy_lhs'
]:
dX
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
dZ
)
else
:
# mul, div, dot
...
...
@@ -152,7 +180,7 @@ class GSDDMM(th.autograd.Function):
dX
=
None
if
op
!=
'copy_lhs'
and
ctx
.
needs_input_grad
[
3
]:
if
rhs_target
in
[
'u'
,
'v'
]:
_gidx
=
gidx
if
rhs_target
==
'v'
else
gidx
.
reverse
()
_gidx
=
gidx
if
rhs_target
==
'v'
else
_
reverse
(
gidx
)
if
op
in
[
'add'
,
'sub'
,
'copy_rhs'
]:
dY
=
gspmm
(
_gidx
,
'copy_rhs'
,
'sum'
,
None
,
_addsub
(
op
,
dZ
))
else
:
# mul, div, dot
...
...
@@ -198,7 +226,7 @@ class EdgeSoftmax(th.autograd.Function):
if
not
is_all
(
eids
):
gidx
=
gidx
.
edge_subgraph
([
eids
],
True
).
graph
if
norm_by
==
'src'
:
gidx
=
gidx
.
reverse
()
gidx
=
_
reverse
(
gidx
)
score_max
=
_gspmm
(
gidx
,
'copy_rhs'
,
'max'
,
None
,
score
)[
0
]
score
=
th
.
exp
(
_gsddmm
(
gidx
,
'sub'
,
score
,
score_max
,
'e'
,
'v'
))
score_sum
=
_gspmm
(
gidx
,
'copy_rhs'
,
'sum'
,
None
,
score
)[
0
]
...
...
python/dgl/heterograph.py
View file @
aa884d43
...
...
@@ -1291,13 +1291,16 @@ class DGLHeteroGraph(object):
>>> import torch
Create a homogeneous graph.
>>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
Manually set batch information
>>> g.set_batch_num_nodes(torch.tensor([3, 3])
>>> g.set_batch_num_edges(torch.tensor([3, 3])
Unbatch the graph.
>>> dgl.unbatch(g)
[Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
...
...
@@ -1433,13 +1436,16 @@ class DGLHeteroGraph(object):
>>> import torch
Create a homogeneous graph.
>>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
Manually set batch information
>>> g.set_batch_num_nodes(torch.tensor([3, 3])
>>> g.set_batch_num_edges(torch.tensor([3, 3])
Unbatch the graph.
>>> dgl.unbatch(g)
[Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
...
...
python/dgl/transform.py
View file @
aa884d43
...
...
@@ -644,6 +644,9 @@ def reverse(g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_eda
:math:`(i_1, j_1), (i_2, j_2), \cdots` of type ``(U, E, V)`` is a new graph with edges
:math:`(j_1, i_1), (j_2, i_2), \cdots` of type ``(V, E, U)``.
The returned graph shares the data structure with the original graph, i.e. dgl.reverse
will not create extra storage for the reversed graph.
Parameters
----------
g : DGLGraph
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment