Unverified commit a613ad88 authored by Zihao Ye, committed by GitHub

Refactor code for retaining formats in message-passing. (#2570)

parent 0f9056ed
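
What the refactor does, in short: every backend's backward pass previously called gidx.reverse() directly, which returns a reversed graph index without the original's sparse-format restrictions; this commit routes those calls through a shared _reverse helper in dgl.sparse that re-applies the restrictions with each format swapped for its transposed counterpart (csr <-> csc, coo unchanged). A toy sketch of the idea (ToyGraphIndex and its methods are hypothetical stand-ins for the real HeteroGraphIndex, for illustration only):

    # Illustration only: hypothetical stand-in for DGL's HeteroGraphIndex,
    # showing the format-loss problem and the fix that _reverse applies.
    _INVERSE_FORMAT = {'coo': 'coo', 'csr': 'csc', 'csc': 'csr'}

    class ToyGraphIndex:
        def __init__(self, edges, allowed):
            self.edges = list(edges)
            self.allowed = list(allowed)

        def reverse(self):
            # Models the behavior being fixed: reversal flips the edges but
            # forgets which sparse formats the graph was restricted to.
            return ToyGraphIndex([(v, u) for u, v in self.edges],
                                 ['coo', 'csr', 'csc'])

        def formats(self, fmts=None):
            if fmts is None:
                return {'created': [], 'not created': list(self.allowed)}
            return ToyGraphIndex(self.edges, fmts)  # clone restricted to fmts

    def _reverse(gidx):
        # The helper added in this commit, in spirit: reverse, then restrict
        # to the transposed counterparts of the originally allowed formats.
        fmts = gidx.formats()
        allowed = fmts['created'] + fmts['not created']
        return gidx.reverse().formats([_INVERSE_FORMAT[f] for f in allowed])

    g = ToyGraphIndex([(0, 1), (0, 2)], allowed=['csr'])
    assert g.reverse().formats()['not created'] == ['coo', 'csr', 'csc']  # lost
    assert _reverse(g).formats()['not created'] == ['csc']                # kept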
python/dgl/backend/mxnet/sparse.py
 import mxnet as mx
 import numpy as np
 from mxnet import nd
-from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
+from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
 from ...base import dgl_warning, is_all, ALL
 from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
@@ -132,7 +132,7 @@ class GSpMM(mx.autograd.Function):
         X, Y, argX, argY = self.saved_tensors
         gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
         if op != 'copy_rhs':
-            g_rev = gidx.reverse()
+            g_rev = _reverse(gidx)
             if reduce_op == 'sum':
                 if op in ['mul', 'div']:
                     dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
@@ -215,7 +215,7 @@ class GSDDMM(mx.autograd.Function):
         lhs_target, rhs_target = self.lhs_target, self.rhs_target
         if op != 'copy_rhs':
             if lhs_target in ['u', 'v']:
-                _gidx = gidx if self.lhs_target == 'v' else gidx.reverse()
+                _gidx = gidx if self.lhs_target == 'v' else _reverse(gidx)
                 if op in ['add', 'sub', 'copy_lhs']:
                     dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
                 else: # mul, div, dot
@@ -235,7 +235,7 @@ class GSDDMM(mx.autograd.Function):
             dX = nd.zeros_like(X)
         if op != 'copy_lhs':
             if self.rhs_target in ['u', 'v']:
-                _gidx = gidx if rhs_target == 'v' else gidx.reverse()
+                _gidx = gidx if rhs_target == 'v' else _reverse(gidx)
                 if op in ['add', 'sub', 'copy_rhs']:
                     dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
                 else: # mul, div, dot
@@ -277,7 +277,7 @@ class EdgeSoftmax(mx.autograd.Function):
         if not is_all(eids):
             gidx = gidx.edge_subgraph([eids], True).graph
         if norm_by == 'src':
-            gidx = gidx.reverse()
+            gidx = _reverse(gidx)
         self.gidx = gidx

     def forward(self, score):
...
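
Why the backward passes reverse the graph at all: a sum-aggregation SpMM computes, for each destination node, a sum over its incoming edges. As a matrix product that is y = A.T @ x, and its gradient dx = A @ dy is the same kind of SpMM over the graph with all edges flipped. A dense NumPy sketch of that identity (for intuition; not DGL code):

    import numpy as np

    # Toy graph with edges 0->1, 0->2, 1->2; A[u, v] = 1 iff u -> v.
    A = np.array([[0., 1., 1.],
                  [0., 0., 1.],
                  [0., 0., 0.]])
    x = np.array([1., 2., 3.])   # one feature per source node

    y = A.T @ x                  # 'copy_u'/'sum' SpMM: each dst sums its in-neighbors
    dy = np.ones_like(y)         # upstream gradient from the loss
    dx = A @ dy                  # = (A.T).T @ dy: the same SpMM on the reversed graph

    print(y)    # [0. 1. 3.]
    print(dx)   # [2. 1. 0.] -- each src's out-degree, as expected for dy = 1

Retaining formats matters here because this reversed index is built on every backward call; with the helper, a graph restricted to one format hands its reverse the transposed restriction instead of an unconstrained index.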
python/dgl/backend/pytorch/sparse.py
 import torch as th
 from ...base import is_all, ALL
-from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
+from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse

 __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']

-_inverse_format = {
-    'coo': 'coo',
-    'csr': 'csc',
-    'csc': 'csr'
-}
-
-def _reverse(gidx):
-    """Reverse the given graph index while retaining its formats.
-
-    Parameters
-    ----------
-    gidx: HeteroGraphIndex
-
-    Return
-    ------
-    HeteroGraphIndex
-    """
-    g_rev = gidx.reverse()
-    original_formats_dict = gidx.formats()
-    original_formats = original_formats_dict['created'] +\
-        original_formats_dict['not created']
-    g_rev = g_rev.formats([_inverse_format[fmt] for fmt in original_formats])
-    return g_rev
-
 def _reduce_grad(grad, shape):
     """Reduce gradient on the broadcast dimension
     If there is broadcast in forward pass, gradients need to be reduced on
...
python/dgl/backend/tensorflow/sparse.py
@@ -2,7 +2,7 @@ import tensorflow as tf
 import numpy as np
 from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy
 from ...base import is_all, ALL
-from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
+from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse

 __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']

@@ -110,7 +110,7 @@ def gspmm_real(gidx, op, reduce_op, X, Y):
     def grad(dZ):
         dZ = tensor(dZ)
         if op != 'copy_rhs':
-            g_rev = gidx.reverse()
+            g_rev = _reverse(gidx)
             if reduce_op == 'sum':
                 if op in ['mul', 'div']:
                     dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
@@ -172,7 +172,7 @@ def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
     def grad(dZ):
         if op != 'copy_rhs':
             if lhs_target in ['u', 'v']:
-                _gidx = gidx if lhs_target == 'v' else gidx.reverse()
+                _gidx = gidx if lhs_target == 'v' else _reverse(gidx)
                 if op in ['add', 'sub', 'copy_lhs']:
                     dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
                 else: # mul, div, dot
@@ -192,7 +192,7 @@ def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
             dX = tf.zeros_like(X)
         if op != 'copy_lhs':
             if rhs_target in ['u', 'v']:
-                _gidx = gidx if rhs_target == 'v' else gidx.reverse()
+                _gidx = gidx if rhs_target == 'v' else _reverse(gidx)
                 if op in ['add', 'sub', 'copy_rhs']:
                     dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
                 else: # mul, div, dot
@@ -233,7 +233,7 @@ def edge_softmax_real(gidx, score, eids=ALL, norm_by='dst'):
     if not is_all(eids):
         gidx = gidx.edge_subgraph([eids], True).graph
     if norm_by == 'src':
-        gidx = gidx.reverse()
+        gidx = _reverse(gidx)
     score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
     score = tf.math.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
     score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
...
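
The edge_softmax_real hunk spells out the numerically stable edge softmax as four sparse primitives: a max-reduce SpMM, an SDDMM subtraction followed by exp, a sum-reduce SpMM, and a normalizing SDDMM division. The same computation in plain NumPy, looping per destination node (a sketch for intuition, not DGL code):

    import numpy as np

    edges = [(0, 2), (1, 2), (0, 1)]    # (src, dst) pairs
    score = np.array([5.0, 2.0, 1.0])   # one logit per edge

    out = np.empty_like(score)
    for v in {d for _, d in edges}:
        idx = [i for i, (_, d) in enumerate(edges) if d == v]
        m = score[idx].max()            # gspmm 'copy_rhs'/'max': per-dst max
        e = np.exp(score[idx] - m)      # gsddmm 'sub' then exp: shifted exponentials
        out[idx] = e / e.sum()          # gspmm 'sum' + gsddmm 'div': normalize

    print(out)  # softmax over {5.0, 2.0} for dst node 2; 1.0 alone for dst node 1

Subtracting the per-destination max before exponentiating changes nothing mathematically but keeps exp from overflowing, which is why the forward pass spends an extra max-reduce SpMM on it.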
python/dgl/heterograph.py
@@ -5449,6 +5449,13 @@ class DGLHeteroGraph(object):
         >>> # Only allowed formats will be displayed in the status query
         >>> csr_g.formats()
         {'created': ['csr'], 'not created': []}
+
+        Notes
+        -----
+        DGL creates sparse formats on the fly during the training of graph
+        neural networks, constrained to the allowed formats (i.e. the union
+        of the created and not-created ones). Once a format is created, it is
+        cached and reused until the user changes the graph structure.
         """
         if formats is None:
             # Return the format information
...
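
The added note pairs naturally with a usage sketch of the documented API (hedged: the exact created/not-created split for a freshly constructed graph may differ across DGL versions):

    import dgl
    import torch as th

    g = dgl.graph((th.tensor([0, 0, 1]), th.tensor([1, 2, 2])))
    print(g.formats())        # e.g. {'created': ['coo'], 'not created': ['csr', 'csc']}

    csr_g = g.formats('csr')  # clone restricted to CSR only
    print(csr_g.formats())    # {'created': ['csr'], 'not created': []} (as in the docstring)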
python/dgl/sparse.py
@@ -7,6 +7,8 @@ from ._ffi.function import _init_api
 from .base import DGLError
 from . import backend as F

+__all__ = ['_gspmm', '_gsddmm', '_segment_reduce', '_bwd_segment_cmp', '_reverse']
+
 def infer_broadcast_shape(op, shp1, shp2):
     r"""Check the shape validity, and infer the output shape given input shape and operator.
@@ -65,6 +67,33 @@ def to_dgl_nd_for_write(x):
     return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x)

+inverse_format = {
+    'coo': 'coo',
+    'csr': 'csc',
+    'csc': 'csr'
+}
+
+def _reverse(gidx):
+    """Reverse the given graph index while retaining its formats.
+
+    ``dgl.reverse`` does not keep graph format information by default.
+
+    Parameters
+    ----------
+    gidx : HeteroGraphIndex
+
+    Returns
+    -------
+    HeteroGraphIndex
+    """
+    g_rev = gidx.reverse()
+    original_formats_dict = gidx.formats()
+    original_formats = original_formats_dict['created'] +\
+        original_formats_dict['not created']
+    g_rev = g_rev.formats([inverse_format[fmt] for fmt in original_formats])
+    return g_rev
+
 target_mapping = {
     'u': 0,
     'e': 1,
...
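
A quick sanity check on the inverse_format table: the CSR arrays of an adjacency matrix are exactly the CSC arrays of its transpose, so a reversed graph can keep serving cached structures under the swapped names ('coo' maps to itself because reversal merely swaps a COO's row and column arrays). In SciPy terms (illustration only, not DGL code):

    import numpy as np
    import scipy.sparse as sp

    A = sp.csr_matrix(np.array([[0, 1, 1],
                                [0, 0, 1],
                                [0, 0, 0]]))
    At = A.T                   # SciPy hands back the transpose as CSC

    # Same index arrays, no rebuild: CSR of A is CSC of A.T.
    assert At.format == 'csc'
    assert np.array_equal(A.indptr, At.indptr)
    assert np.array_equal(A.indices, At.indices)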