Unverified Commit a613ad88 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

Refactor code for retaining formats in message-passing. (#2570)

parent 0f9056ed
import mxnet as mx
import numpy as np
from mxnet import nd
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
from ...base import dgl_warning, is_all, ALL
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
......@@ -132,7 +132,7 @@ class GSpMM(mx.autograd.Function):
X, Y, argX, argY = self.saved_tensors
gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
if op != 'copy_rhs':
g_rev = gidx.reverse()
g_rev = _reverse(gidx)
if reduce_op == 'sum':
if op in ['mul', 'div']:
dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
......@@ -215,7 +215,7 @@ class GSDDMM(mx.autograd.Function):
lhs_target, rhs_target = self.lhs_target, self.rhs_target
if op != 'copy_rhs':
if lhs_target in ['u', 'v']:
_gidx = gidx if self.lhs_target == 'v' else gidx.reverse()
_gidx = gidx if self.lhs_target == 'v' else _reverse(gidx)
if op in ['add', 'sub', 'copy_lhs']:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
else: # mul, div, dot
......@@ -235,7 +235,7 @@ class GSDDMM(mx.autograd.Function):
dX = nd.zeros_like(X)
if op != 'copy_lhs':
if self.rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
_gidx = gidx if rhs_target == 'v' else _reverse(gidx)
if op in ['add', 'sub', 'copy_rhs']:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
else: # mul, div, dot
......@@ -277,7 +277,7 @@ class EdgeSoftmax(mx.autograd.Function):
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src':
gidx = gidx.reverse()
gidx = _reverse(gidx)
self.gidx = gidx
def forward(self, score):
......
import torch as th
from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
_inverse_format = {
'coo': 'coo',
'csr': 'csc',
'csc': 'csr'
}
def _reverse(gidx):
"""Reverse the given graph index while retaining its formats.
Parameters
----------
gidx: HeteroGraphIndex
Return
------
HeteroGraphIndex
"""
g_rev = gidx.reverse()
original_formats_dict = gidx.formats()
original_formats = original_formats_dict['created'] +\
original_formats_dict['not created']
g_rev = g_rev.formats([_inverse_format[fmt] for fmt in original_formats])
return g_rev
def _reduce_grad(grad, shape):
"""Reduce gradient on the broadcast dimension
If there is broadcast in forward pass, gradients need to be reduced on
......
......@@ -2,7 +2,7 @@ import tensorflow as tf
import numpy as np
from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy
from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
......@@ -110,7 +110,7 @@ def gspmm_real(gidx, op, reduce_op, X, Y):
def grad(dZ):
dZ = tensor(dZ)
if op != 'copy_rhs':
g_rev = gidx.reverse()
g_rev = _reverse(gidx)
if reduce_op == 'sum':
if op in ['mul', 'div']:
dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
......@@ -172,7 +172,7 @@ def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
def grad(dZ):
if op != 'copy_rhs':
if lhs_target in ['u', 'v']:
_gidx = gidx if lhs_target == 'v' else gidx.reverse()
_gidx = gidx if lhs_target == 'v' else _reverse(gidx)
if op in ['add', 'sub', 'copy_lhs']:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
else: # mul, div, dot
......@@ -192,7 +192,7 @@ def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
dX = tf.zeros_like(X)
if op != 'copy_lhs':
if rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
_gidx = gidx if rhs_target == 'v' else _reverse(gidx)
if op in ['add', 'sub', 'copy_rhs']:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
else: # mul, div, dot
......@@ -233,7 +233,7 @@ def edge_softmax_real(gidx, score, eids=ALL, norm_by='dst'):
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src':
gidx = gidx.reverse()
gidx = _reverse(gidx)
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = tf.math.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
......
......@@ -5449,6 +5449,13 @@ class DGLHeteroGraph(object):
>>> # Only allowed formats will be displayed in the status query
>>> csr_g.formats()
{'created': ['csr'], 'not created': []}
Notes
-----
DGL will create sparse formats (only constrained to the allowed formats, i.e.
created formats and not created formats) on-the-fly during the training of Graph
Neural Networks. Once a format was created, it would be cached and reused until
user changes the graph structure.
"""
if formats is None:
# Return the format information
......
......@@ -7,6 +7,8 @@ from ._ffi.function import _init_api
from .base import DGLError
from . import backend as F
__all__ = ['_gspmm', '_gsddmm', '_segment_reduce', '_bwd_segment_cmp', '_reverse']
def infer_broadcast_shape(op, shp1, shp2):
r"""Check the shape validity, and infer the output shape given input shape and operator.
......@@ -65,6 +67,33 @@ def to_dgl_nd_for_write(x):
return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x)
inverse_format = {
'coo': 'coo',
'csr': 'csc',
'csc': 'csr'
}
def _reverse(gidx):
"""Reverse the given graph index while retaining its formats.
``dgl.reverse`` would not keep graph format information by default.
Parameters
----------
gidx: HeteroGraphIndex
Return
------
HeteroGraphIndex
"""
g_rev = gidx.reverse()
original_formats_dict = gidx.formats()
original_formats = original_formats_dict['created'] +\
original_formats_dict['not created']
g_rev = g_rev.formats([inverse_format[fmt] for fmt in original_formats])
return g_rev
target_mapping = {
'u': 0,
'e': 1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment