Unverified Commit aa884d43 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[doc][fix] Improve the docstring and fix its behavior in DGL's kernel (#2563)

* upd

* fix

* lint

* fix

* upd
parent a6abffe3
...@@ -5,6 +5,34 @@ from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp ...@@ -5,6 +5,34 @@ from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce'] __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
_inverse_format = {
'coo': 'coo',
'csr': 'csc',
'csc': 'csr'
}
def _reverse(gidx):
"""Reverse the given graph index while retaining its formats.
Parameters
----------
gidx: HeteroGraphIndex
Return
------
HeteroGraphIndex
"""
g_rev = gidx.reverse()
original_formats_dict = gidx.formats()
original_formats = original_formats_dict['created'] +\
original_formats_dict['not created']
g_rev = g_rev.formats([_inverse_format[fmt] for fmt in original_formats])
return g_rev
def _reduce_grad(grad, shape): def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension """Reduce gradient on the broadcast dimension
If there is broadcast in forward pass, gradients need to be reduced on If there is broadcast in forward pass, gradients need to be reduced on
...@@ -71,7 +99,7 @@ class GSpMM(th.autograd.Function): ...@@ -71,7 +99,7 @@ class GSpMM(th.autograd.Function):
gidx, op, reduce_op = ctx.backward_cache gidx, op, reduce_op = ctx.backward_cache
X, Y, argX, argY = ctx.saved_tensors X, Y, argX, argY = ctx.saved_tensors
if op != 'copy_rhs' and ctx.needs_input_grad[3]: if op != 'copy_rhs' and ctx.needs_input_grad[3]:
g_rev = gidx.reverse() g_rev = _reverse(gidx)
if reduce_op == 'sum': if reduce_op == 'sum':
if op in ['mul', 'div']: if op in ['mul', 'div']:
dX = gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y)) dX = gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))
...@@ -132,7 +160,7 @@ class GSDDMM(th.autograd.Function): ...@@ -132,7 +160,7 @@ class GSDDMM(th.autograd.Function):
X, Y = ctx.saved_tensors X, Y = ctx.saved_tensors
if op != 'copy_rhs' and ctx.needs_input_grad[2]: if op != 'copy_rhs' and ctx.needs_input_grad[2]:
if lhs_target in ['u', 'v']: if lhs_target in ['u', 'v']:
_gidx = gidx if lhs_target == 'v' else gidx.reverse() _gidx = gidx if lhs_target == 'v' else _reverse(gidx)
if op in ['add', 'sub', 'copy_lhs']: if op in ['add', 'sub', 'copy_lhs']:
dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ) dX = gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)
else: # mul, div, dot else: # mul, div, dot
...@@ -152,7 +180,7 @@ class GSDDMM(th.autograd.Function): ...@@ -152,7 +180,7 @@ class GSDDMM(th.autograd.Function):
dX = None dX = None
if op != 'copy_lhs' and ctx.needs_input_grad[3]: if op != 'copy_lhs' and ctx.needs_input_grad[3]:
if rhs_target in ['u', 'v']: if rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else gidx.reverse() _gidx = gidx if rhs_target == 'v' else _reverse(gidx)
if op in ['add', 'sub', 'copy_rhs']: if op in ['add', 'sub', 'copy_rhs']:
dY = gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ)) dY = gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))
else: # mul, div, dot else: # mul, div, dot
...@@ -198,7 +226,7 @@ class EdgeSoftmax(th.autograd.Function): ...@@ -198,7 +226,7 @@ class EdgeSoftmax(th.autograd.Function):
if not is_all(eids): if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src': if norm_by == 'src':
gidx = gidx.reverse() gidx = _reverse(gidx)
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0] score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = th.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v')) score = th.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0] score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
......
...@@ -1291,13 +1291,16 @@ class DGLHeteroGraph(object): ...@@ -1291,13 +1291,16 @@ class DGLHeteroGraph(object):
>>> import torch >>> import torch
Create a homogeneous graph. Create a homogeneous graph.
>>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3])) >>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
Manually set batch information Manually set batch information
>>> g.set_batch_num_nodes(torch.tensor([3, 3]) >>> g.set_batch_num_nodes(torch.tensor([3, 3])
>>> g.set_batch_num_edges(torch.tensor([3, 3]) >>> g.set_batch_num_edges(torch.tensor([3, 3])
Unbatch the graph. Unbatch the graph.
>>> dgl.unbatch(g) >>> dgl.unbatch(g)
[Graph(num_nodes=3, num_edges=3, [Graph(num_nodes=3, num_edges=3,
ndata_schemes={} ndata_schemes={}
...@@ -1433,13 +1436,16 @@ class DGLHeteroGraph(object): ...@@ -1433,13 +1436,16 @@ class DGLHeteroGraph(object):
>>> import torch >>> import torch
Create a homogeneous graph. Create a homogeneous graph.
>>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3])) >>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
Manually set batch information Manually set batch information
>>> g.set_batch_num_nodes(torch.tensor([3, 3]) >>> g.set_batch_num_nodes(torch.tensor([3, 3])
>>> g.set_batch_num_edges(torch.tensor([3, 3]) >>> g.set_batch_num_edges(torch.tensor([3, 3])
Unbatch the graph. Unbatch the graph.
>>> dgl.unbatch(g) >>> dgl.unbatch(g)
[Graph(num_nodes=3, num_edges=3, [Graph(num_nodes=3, num_edges=3,
ndata_schemes={} ndata_schemes={}
......
...@@ -644,6 +644,9 @@ def reverse(g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_eda ...@@ -644,6 +644,9 @@ def reverse(g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_eda
:math:`(i_1, j_1), (i_2, j_2), \cdots` of type ``(U, E, V)`` is a new graph with edges :math:`(i_1, j_1), (i_2, j_2), \cdots` of type ``(U, E, V)`` is a new graph with edges
:math:`(j_1, i_1), (j_2, i_2), \cdots` of type ``(V, E, U)``. :math:`(j_1, i_1), (j_2, i_2), \cdots` of type ``(V, E, U)``.
The returned graph shares the data structure with the original graph, i.e. dgl.reverse
will not create extra storage for the reversed graph.
Parameters Parameters
---------- ----------
g : DGLGraph g : DGLGraph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment