Unverified Commit aa884d43 authored by Zihao Ye, committed by GitHub

[doc][fix] Improve the docstring and fix its behavior in DGL's kernel (#2563)

* upd

* fix

* lint

* fix

* upd
parent a6abffe3
@@ -5,6 +5,34 @@ from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp

__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']


_inverse_format = {
    'coo': 'coo',
    'csr': 'csc',
    'csc': 'csr'
}


def _reverse(gidx):
    """Reverse the given graph index while retaining its formats.

    Parameters
    ----------
    gidx : HeteroGraphIndex
        The input graph index.

    Returns
    -------
    HeteroGraphIndex
        The reversed graph index, with each allowed sparse format mapped to
        its transposed counterpart (coo -> coo, csr -> csc, csc -> csr).
    """
    g_rev = gidx.reverse()
    original_formats_dict = gidx.formats()
    original_formats = original_formats_dict['created'] + \
        original_formats_dict['not created']
    g_rev = g_rev.formats([_inverse_format[fmt] for fmt in original_formats])
    return g_rev


def _reduce_grad(grad, shape):
    """Reduce gradient on the broadcast dimension.

    If there is broadcast in forward pass, gradients need to be reduced on
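The `_reverse` helper added in this hunk is the substance of the fix: a bare `gidx.reverse()` forgets which sparse formats were already materialized, while the helper relabels them so the transposed structure is reused. A minimal sketch of the effect, assuming a post-patch DGL build; note it reaches into private APIs (`g._graph`, the helper itself) purely for illustration:

import dgl
import torch as th
# private helper from the patched file, imported only to illustrate
from dgl.backend.pytorch.sparse import _reverse

g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))
g = g.formats('csr')      # restrict the graph to the CSR format only
gidx = g._graph           # internal HeteroGraphIndex (implementation detail)

g_rev = _reverse(gidx)    # reversed index; CSR relabeled as CSC
print(gidx.formats())     # formats of the original index
print(g_rev.formats())    # same structures with csr/csc swapped, coo unchanged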
@@ -71,7 +99,7 @@ class GSpMM(th.autograd.Function):
        gidx, op, reduce_op = ctx.backward_cache
        X, Y, argX, argY = ctx.saved_tensors
        if op != 'copy_rhs' and ctx.needs_input_grad[3]:
-            g_rev = gidx.reverse()
+            g_rev = _reverse(gidx)
            if reduce_op == 'sum':
                if op in ['mul', 'div']:
                    dX = gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))
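The reversed graph appears here because the gradient of a sum-reduce SpMM is itself an SpMM along flipped edges: with A the (dst, src) adjacency, Z = A X implies dX = A^T dZ, and A^T is exactly the adjacency of the reversed graph. A dense PyTorch stand-in (not DGL code) that checks this identity:

import torch as th

A = th.tensor([[0., 1.], [1., 0.], [1., 1.]])  # 3 dst nodes x 2 src nodes
X = th.randn(2, 4, requires_grad=True)         # source-node features
Z = A @ X                                      # sum-reduce 'copy_lhs' SpMM
dZ = th.ones_like(Z)
Z.backward(dZ)
assert th.allclose(X.grad, A.t() @ dZ)         # gradient = SpMM on the reverse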
@@ -132,7 +160,7 @@ class GSDDMM(th.autograd.Function):
        X, Y = ctx.saved_tensors
        if op != 'copy_rhs' and ctx.needs_input_grad[2]:
            if lhs_target in ['u', 'v']:
-                _gidx = gidx if lhs_target == 'v' else gidx.reverse()
+                _gidx = gidx if lhs_target == 'v' else _reverse(gidx)
                if op in ['add', 'sub', 'copy_lhs']:
                    dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
                else:  # mul, div, dot
@@ -152,7 +180,7 @@ class GSDDMM(th.autograd.Function):
            dX = None
        if op != 'copy_lhs' and ctx.needs_input_grad[3]:
            if rhs_target in ['u', 'v']:
-                _gidx = gidx if rhs_target == 'v' else gidx.reverse()
+                _gidx = gidx if rhs_target == 'v' else _reverse(gidx)
                if op in ['add', 'sub', 'copy_rhs']:
                    dY = gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))
                else:  # mul, div, dot
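In both SDDMM branches the reversal serves the same purpose: SpMM aggregates edge values at destination nodes, so a gradient that must land on source-side ('u') features needs the edges flipped first. A tiny PyTorch-only check of that scatter identity, on a hypothetical toy edge list (not DGL code):

import torch as th

src = th.tensor([0, 0, 1])             # toy graph: edges 0->1, 0->2, 1->2
dst = th.tensor([1, 2, 2])
X = th.randn(3, requires_grad=True)    # per-node lhs ('u') features
Z = X[src]                             # 'copy_lhs' SDDMM: one value per edge
dZ = th.tensor([1., 2., 3.])
Z.backward(dZ)
# summing edge gradients onto their *source* nodes matches autograd
manual = th.zeros(3).index_add_(0, src, dZ)
assert th.allclose(X.grad, manual)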
@@ -198,7 +226,7 @@ class EdgeSoftmax(th.autograd.Function):
        if not is_all(eids):
            gidx = gidx.edge_subgraph([eids], True).graph
        if norm_by == 'src':
-            gidx = gidx.reverse()
+            gidx = _reverse(gidx)
        score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
        score = th.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
        score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
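For reference, edge softmax normalizes edge scores over each node's incoming edges (norm_by='dst'); norm_by='src' is the branch that reverses the graph first. A small sketch through the public wrapper, assuming a DGL release that exposes dgl.ops.edge_softmax:

import dgl
import dgl.ops as ops
import torch as th

g = dgl.graph(([0, 1, 2], [2, 2, 1]))     # two edges into node 2, one into 1
score = th.tensor([[1.0], [2.0], [5.0]])
print(ops.edge_softmax(g, score))
# edges into node 2 share softmax([1., 2.]); the lone edge into node 1 gets 1.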
@@ -1291,13 +1291,16 @@ class DGLHeteroGraph(object):
        >>> import torch

        Create a homogeneous graph.

        >>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))

        Manually set batch information.

        >>> g.set_batch_num_nodes(torch.tensor([3, 3]))
        >>> g.set_batch_num_edges(torch.tensor([3, 3]))

        Unbatch the graph.

        >>> dgl.unbatch(g)
        [Graph(num_nodes=3, num_edges=3,
               ndata_schemes={}
@@ -1433,13 +1436,16 @@ class DGLHeteroGraph(object):
        >>> import torch

        Create a homogeneous graph.

        >>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))

        Manually set batch information.

        >>> g.set_batch_num_nodes(torch.tensor([3, 3]))
        >>> g.set_batch_num_edges(torch.tensor([3, 3]))

        Unbatch the graph.

        >>> dgl.unbatch(g)
        [Graph(num_nodes=3, num_edges=3,
               ndata_schemes={}
@@ -644,6 +644,9 @@ def reverse(g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_edata=None):
    :math:`(i_1, j_1), (i_2, j_2), \cdots` of type ``(U, E, V)`` is a new graph with edges
    :math:`(j_1, i_1), (j_2, i_2), \cdots` of type ``(V, E, U)``.

    The returned graph shares the data structure with the original graph, i.e.,
    ``dgl.reverse`` will not create extra storage for the reversed graph.

    Parameters
    ----------
    g : DGLGraph
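A quick illustration of the documented sharing behavior on a hypothetical toy graph (default copy_ndata=True): edge directions are flipped while the structure and node features are reused rather than copied:

import dgl
import torch as th

g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
g.ndata['h'] = th.arange(3, dtype=th.float32).unsqueeze(1)
rg = dgl.reverse(g)        # copy_ndata=True by default
print(rg.edges())          # (tensor([1, 2]), tensor([0, 1])) -- flipped
print(rg.ndata['h'])       # same feature tensor, no extra storage created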