Unverified commit cb0e1103 authored by Israt Nisa, committed by GitHub

[Feature] Add Min/max reducer in heterogeneous API for unary message functions (#3514)



* min/max support for forward CPU heterograph

* Added etype with each argU values

* scatter_add needs fix

* added scatter_add_hetero. Grads don't match for max reducer

* storing ntype in argX

* fixing scatter_add_hetero

* hetero matches with torch's scatter add

* works copy_e forward+cpu

* added backward for copy_rhs

* Computes gradient for all node types in one kernel

* bug fix

* unit test for max/min on CPU

* renamed scatter_add_hetero to update_grad_minmax_hetero

* lint check and comment out cuda call for max. Code is for CPU only

* lint check

* replace inf with zero

* minor

* lint check

* removed LIBXSMM code from hetero code

* fixing backward operator of UpdateGradMinMaxHetero

* removed backward from update_grad_minmax_hetero

* docstring

* improved docstring and coding style

* Added pass by pointer for output

* typos and pass by references

* Support for copy_rhs

* Added header <string>

* fix bug in copy_u_max

* Added comments and dimension check of all etypes

* skip mxnet check

* pass by pointer output arrays

* updated docstring
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 769718df
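For context, a minimal usage sketch of what this change enables (node/edge type names and feature names below are illustrative, not from the patch): `update_all` with a `max`/`min` reducer on a full heterogeneous graph, which previously raised `NotImplementedError`.

```python
import torch as th
import dgl
import dgl.function as fn

g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'): ([0, 1, 2], [0, 0, 1]),
})
# Every source node type must provide the feature with the same width.
g.nodes['user'].data['h'] = th.randn(3, 4)
g.update_all(fn.copy_u('h', 'm'), fn.max('m', 'hmax'))
print(g.nodes['game'].data['hmax'].shape)  # torch.Size([2, 4])
print(g.nodes['user'].data['hmax'].shape)  # torch.Size([3, 4]); zero-degree rows come back as 0
```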
import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL
-from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _scatter_add
-from ...sparse import _csrmm, _csrsum, _csrmask
+from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp
+from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
from ...heterograph_index import create_unitgraph_from_csr
if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
@@ -194,10 +194,9 @@ class GSpMM_hetero(th.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=th.float16)
    def forward(ctx, gidx, op, reduce_op, X_len, *feats):  # feats = lhs_data + rhs_data
-        out, (argX, argY) = _gspmm_hetero(gidx, op, reduce_op, X_len, feats)
+        out, (argX, argY, argX_ntype, argY_etype) = _gspmm_hetero(gidx, op, reduce_op, X_len, feats)
        X, Y = feats[:X_len], feats[X_len:]
        # TODO (Israt): check target to decide src_id/dst_id?
-        # checking the first relation to decide for all the relations
        src_id, dst_id = gidx.metagraph.find_edge(0)
        reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id])
        X_shape = tuple([X[i].shape if X[i] is not None else None
@@ -214,11 +213,11 @@ class GSpMM_hetero(th.autograd.Function):
        # checking the first relation to decide for all the relations
        if not spmm_cache_argX(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]):
-            argX = None
+            argX = tuple([None] * len(X))
        if not spmm_cache_argY(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]):
-            argY = None
+            argY = tuple([None] * len(X))
-        ctx.save_for_backward(*feats, argX, argY)
+        ctx.save_for_backward(*feats, *argX, *argX_ntype, *argY, *argY_etype)
        return out
    @staticmethod
@@ -226,14 +225,16 @@ class GSpMM_hetero(th.autograd.Function):
    def backward(ctx, *dZ):
        gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
        ctx.backward_cache = None
-        feats = ctx.saved_tensors[:-2]
-        argX = ctx.saved_tensors[-2]
-        argY = ctx.saved_tensors[-1]
+        num_ntypes = gidx.number_of_ntypes()
+        feats = ctx.saved_tensors[:-(4 * num_ntypes)]
+        argX = ctx.saved_tensors[-(4 * num_ntypes):-(3 * num_ntypes)]
+        argX_ntype = ctx.saved_tensors[-(3 * num_ntypes):-(2 * num_ntypes)]
+        argY = ctx.saved_tensors[-(2 * num_ntypes):-num_ntypes]
+        argY_etype = ctx.saved_tensors[-num_ntypes:]
        X, Y = feats[:X_len], feats[X_len:]
        if op != 'copy_rhs' and any([x is not None for x in X]):
            g_rev = gidx.reverse()
-            # TODO(Israt): implement other combinations of message and reduce functions
            if reduce_op == 'sum':
                if op == 'mul':
                    dX = gspmm_hetero(g_rev, 'mul', 'sum', len(X), *tuple(dZ + Y))
@@ -242,6 +243,17 @@ class GSpMM_hetero(th.autograd.Function):
                elif op == 'copy_lhs':
                    tpl_None = tuple([None] * len(Y))
                    dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', len(X), *tuple(dZ + tpl_None))
+            else:  # max/min
+                # Assuming that the features are of the same dimension (enforced by the forward function)
+                src_id, dst_id = gidx.metagraph.find_edge(0)
+                dX = tuple([th.zeros((X_shape[i][0],) + dZ[dst_id].shape[1:], dtype=dtype, device=device)
+                            if X[i] is not None else None for i in range(len(X))])
+                if op == 'mul':
+                    grad = _expand(Y, dZ.shape[1:]).gather(
+                        0, argY.long()) * dZ
+                    dX.scatter_add_(0, argX.long(), grad)
+                elif op in ['add', 'copy_lhs']:
+                    dX = _update_grad_minmax_hetero(g_rev, op, dZ, argX, argX_ntype, dX)
            dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None
                        for i in range(len(X))])
        else:  # X has not gradient
@@ -258,8 +270,18 @@ class GSpMM_hetero(th.autograd.Function):
                    dY = gsddmm_hetero(gidx, 'mul', X_len, 'u', 'v', *tpl_X_dZ)
                elif op in ['add', 'copy_rhs']:
                    dY = gsddmm_hetero(gidx, 'copy_rhs', X_len, 'u', 'v', *tpl_X_dZ)
-            dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None
-                        for i in range(len(Y))])
+            else:  # max/min
+                src_id, dst_id = gidx.metagraph.find_edge(0)
+                dY = tuple([th.zeros((Y_shape[i][0],) + dZ[dst_id].shape[1:], dtype=dtype, device=device)
+                            if Y[i] is not None else None for i in range(len(Y))])
+                if op == 'mul':
+                    grad = _expand(X, dZ.shape[1:]).gather(
+                        0, argX.long()) * dZ
+                    dY.scatter_add_(0, argY.long(), grad)
+                elif op in ['add', 'copy_rhs']:
+                    dY = _update_grad_minmax_hetero(gidx.reverse(), op, dZ, argY, argY_etype, dY)
+            dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if dY[i] is not None else None
+                        for i in range(len(dY))])
        else:  # Y has no gradient
            dY = tuple([None] * len(Y))
        return (None, None, None, None) + dX + dY
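A quick gradient sanity check in the spirit of the unit test added by this commit (a sketch only; the toy graph, type names, and feature names are illustrative): with a cross-type max, every output slot routes its upstream gradient back to exactly one winning source slot, using the recorded argmax index plus node-type id.

```python
import torch as th
import dgl
import dgl.function as fn

g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('topic', 'liked_by', 'user'): ([0, 1], [1, 2]),
})
hu = th.randn(3, 4, requires_grad=True)   # 'user' features
ht = th.randn(2, 4, requires_grad=True)   # 'topic' features (same width)
g.nodes['user'].data['h'] = hu
g.nodes['topic'].data['h'] = ht
g.update_all(fn.copy_u('h', 'm'), fn.max('m', 'out'))
g.nodes['user'].data['out'].sum().backward()
# hu.grad / ht.grad count how often each source slot won the max across both etypes.
print(hu.grad, ht.grad)
```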
@@ -273,7 +295,7 @@ def sddmm_cache_X(op, req_grad_X, req_grad_Y):
def sddmm_cache_Y(op, req_grad_X, req_grad_Y):
    """Rules to identify whether to cache Y in SDDMM forward stage."""
    if op in ['mul', 'dot'] and req_grad_X:
        return True
    return False
@@ -424,6 +446,7 @@
            dY = tuple([None] * len(Y))
        return (None, None, None, None, None) + dX + dY

class EdgeSoftmax(th.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=th.float16)
......
@@ -4872,10 +4872,6 @@ class DGLHeteroGraph(object):
            raise DGLError("User defined functions are not yet "
                           "supported in update_all for heterogeneous graphs. "
                           "Please use multi_update_all instead.")
-        if reduce_func.name in ['max', 'min']:
-            raise NotImplementedError("Reduce op \'" + reduce_func.name + "\' is not yet "
-                                      "supported in update_all for heterogeneous graphs. "
-                                      "Please use multi_update_all instead.")
        if reduce_func.name in ['mean']:
            raise NotImplementedError("Cannot set both intra-type and inter-type reduce "
                                      "operators as 'mean' using update_all. Please use "
......
@@ -87,7 +87,14 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
        ret = gspmm_internal_hetero(g._graph, op,
                                    'sum' if reduce_op == 'mean' else reduce_op,
                                    len(lhs_data), *lhs_and_rhs_tuple)
-        # TODO (Israt): Add support for 'max', 'min', 'mean' in heterograph
+        # `update_all` on heterogeneous graphs replaces the inf values with zeros in
+        # the final output (after processing all etypes). `multi_update_all` performs
+        # this replacement after processing each etype and computes the final output
+        # from the per-etype outputs, where inf has already been replaced by zero.
+        if reduce_op in ['min', 'max']:
+            ret = tuple([F.replace_inf_with_zero(ret[i]) if ret[i] is not None else None
+                         for i in range(len(ret))])
+        # TODO (Israt): Add support for 'mean' in heterograph
    # divide in degrees for mean reducer.
    if reduce_op == 'mean':
        ret_shape = F.shape(ret)
......
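A standalone sketch (plain PyTorch; the backend helper referenced above is assumed to behave like this) of the zero-degree handling: destination rows that receive no message keep the reducer's identity (+inf for min, -inf for max) in the kernel output, and the wrapper maps them back to zero.

```python
import torch as th

def replace_inf_with_zero(x):
    # Map +/-inf entries (zero in-degree rows) back to zero.
    return th.where(th.isinf(x), th.zeros_like(x), x)

out = th.tensor([[float('-inf'), 2.0],
                 [3.0, float('-inf')]])
print(replace_inf_with_zero(out))  # tensor([[0., 2.], [3., 0.]])
```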
@@ -190,7 +190,51 @@ def _gspmm(gidx, op, reduce_op, u, e):
def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
-    r""" Generalized Sparse Matrix Multiplication interface.
+    r""" Generalized Sparse Matrix Multiplication interface on heterogeneous graphs.
+
+    It handles multiple node and edge types of the graph. For each edge type, it applies
+    :attr:`op` to the source node feature and the edge feature to produce a message on
+    the edge, then aggregates the messages by :attr:`reduce_op` on the destination nodes
+    of that edge type.
+
+    .. math::
+        x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
+
+    where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,
+    :math:`x_e` refer to :attr:`u`, :attr:`e` respectively. :math:`\rho` is the binary
+    operator :attr:`op` and :math:`\psi` is the reduce operator :attr:`reduce_op`;
+    :math:`\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.
+
+    Parameters
+    ----------
+    gidx : HeteroGraphIndex
+        The input graph index.
+    op : str
+        The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``copy_lhs``,
+        ``copy_rhs``.
+    reduce_op : str
+        Reduce operator, could be ``sum``, ``max``, ``min``.
+    u_len : int
+        The number of tensors in ``u`` (source node features).
+    u_and_e_tuple : tuple of tensors
+        Tuple of source node features and edge features. ``u_and_e_tuple[:u_len]``
+        stores the source node features of all source node types and
+        ``u_and_e_tuple[u_len:]`` stores the edge features of all edge types.
+        The source node features may be None if op is ``copy_rhs``.
+        The edge features may be None if op is ``copy_lhs``.
+
+    Returns
+    -------
+    tuple
+        The returned tuple is composed of two elements:
+        - The first element is the tuple of result tensors.
+        - The second element is a tuple composed of arg_u and arg_e together with the
+          node/edge type ids of the argmax/argmin entries (useful when the reducer is
+          ``min``/``max``).
+
+    Notes
+    -----
+    This function does not handle gradients.
    """
    u_tuple, e_tuple = u_and_e_tuple[:u_len], u_and_e_tuple[u_len:]
    use_u = op != 'copy_rhs'
@@ -199,11 +243,25 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
    # deal with scalar features.
    expand_u, expand_e = False, False
-    list_u = [None] * gidx.number_of_ntypes()
-    list_v = [None] * gidx.number_of_ntypes()
-    list_e = [None] * gidx.number_of_etypes()
+    num_ntypes = gidx.number_of_ntypes()
+    num_etypes = gidx.number_of_etypes()
+    list_u = [None] * num_ntypes
+    list_v = [None] * num_ntypes
+    list_e = [None] * num_etypes
+    list_arg_u_nd = [None] * num_ntypes
+    list_arg_u = [None] * num_ntypes
+    list_arg_u_ntype_nd = [None] * num_ntypes
+    list_arg_u_ntype = [None] * num_ntypes
+    # TODO(Israt): double check ntype or etype
+    list_arg_e_nd = [None] * num_ntypes
+    list_arg_e = [None] * num_ntypes
+    list_arg_e_etype_nd = [None] * num_ntypes
+    list_arg_e_etype = [None] * num_ntypes
-    for etid in range(gidx.number_of_etypes()):
+    use_cmp = reduce_op in ['max', 'min']
+    idtype = getattr(F, gidx.dtype)
+    for etid in range(num_etypes):
        src_id, dst_id = gidx.metagraph.find_edge(etid)
        u = u_tuple[src_id] if use_u else None
        e = e_tuple[etid] if use_e else None
@@ -224,29 +282,42 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
        v_shp = (gidx.number_of_nodes(dst_id), ) +\
            infer_broadcast_shape(op, u_shp[1:], e_shp[1:])
        list_v[dst_id] = F.zeros(v_shp, dtype, ctx)
+        if use_cmp:
+            if use_u:
+                list_arg_u[dst_id] = F.zeros(v_shp, idtype, ctx)
+                list_arg_u_ntype[dst_id] = F.zeros(v_shp, idtype, ctx)
+            if use_e:
+                list_arg_e[dst_id] = F.zeros(v_shp, idtype, ctx)
+                list_arg_e_etype[dst_id] = F.zeros(v_shp, idtype, ctx)
+        list_arg_u_nd[dst_id] = to_dgl_nd_for_write(list_arg_u[dst_id])
+        list_arg_u_ntype_nd[dst_id] = to_dgl_nd_for_write(list_arg_u_ntype[dst_id])
+        list_arg_e_nd[dst_id] = to_dgl_nd_for_write(list_arg_e[dst_id])
+        list_arg_e_etype_nd[dst_id] = to_dgl_nd_for_write(list_arg_e_etype[dst_id])
-    use_cmp = reduce_op in ['max', 'min']
-    arg_u, arg_e = None, None
-    idtype = getattr(F, gidx.dtype)
-    if use_cmp:
-        if use_u:
-            arg_u = F.zeros(v_shp, idtype, ctx)
-        if use_e:
-            arg_e = F.zeros(v_shp, idtype, ctx)
-    arg_u_nd = to_dgl_nd_for_write(arg_u)
-    arg_e_nd = to_dgl_nd_for_write(arg_e)
    if gidx.number_of_edges(0) > 0:
        _CAPI_DGLKernelSpMMHetero(gidx, op, reduce_op,
                                  [to_dgl_nd(u_i) for u_i in list_u],
                                  [to_dgl_nd(e_i) for e_i in list_e],
                                  [to_dgl_nd_for_write(v_i) for v_i in list_v],
-                                  arg_u_nd,
-                                  arg_e_nd)
+                                  list_arg_u_nd, list_arg_e_nd,
+                                  list_arg_u_ntype_nd, list_arg_e_etype_nd)
-    arg_u = None if arg_u is None else F.zerocopy_from_dgl_ndarray(arg_u_nd)
-    arg_e = None if arg_e is None else F.zerocopy_from_dgl_ndarray(arg_e_nd)
+    for l, arg_u_nd in enumerate(list_arg_u_nd):
+        # TODO(Israt): l or src_id as index of lhs
+        list_arg_u[l] = None if list_arg_u[l] is None else F.zerocopy_from_dgl_ndarray(arg_u_nd)
+        if expand_u and use_cmp:
+            list_arg_u[l] = F.squeeze(list_arg_u[l], -1)
+    for l, arg_e_nd in enumerate(list_arg_e_nd):
+        list_arg_e[l] = None if list_arg_e[l] is None else F.zerocopy_from_dgl_ndarray(arg_e_nd)
+        if expand_e and use_cmp:
+            list_arg_e[l] = F.squeeze(list_arg_e[l], -1)
+    for l, arg_u_ntype_nd in enumerate(list_arg_u_ntype_nd):
+        list_arg_u_ntype[l] = None if arg_u_ntype_nd is None \
+            else F.zerocopy_from_dgl_ndarray(arg_u_ntype_nd)
+    for l, arg_e_etype_nd in enumerate(list_arg_e_etype_nd):
+        list_arg_e_etype[l] = None if arg_e_etype_nd is None \
+            else F.zerocopy_from_dgl_ndarray(arg_e_etype_nd)
    # To deal with scalar node/edge features.
-    for l in range(gidx.number_of_ntypes()):
+    for l in range(num_ntypes):
        # replace None by empty tensor. Forward func doesn't accept None in tuple.
        v = list_v[l]
        v = F.tensor([]) if v is None else v
@@ -254,12 +325,7 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
            v = F.squeeze(v, -1)  # To deal with scalar node/edge features.
        list_v[l] = v
    out = tuple(list_v)
-    if expand_u and use_cmp:
-        arg_u = F.squeeze(arg_u, -1)
-    if expand_e and use_cmp:
-        arg_e = F.squeeze(arg_e, -1)
-    return out, (arg_u, arg_e)
+    return out, (list_arg_u, list_arg_e, list_arg_u_ntype, list_arg_e_etype)
def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
@@ -481,6 +547,50 @@ def _scatter_add(x, idx, m):
    return out

+def _update_grad_minmax_hetero(gidx, op, list_x, list_idx, list_idx_etype, list_dX):
+    r""" Update gradients for the reduce operators max and min (on the first dimension).
+
+    Parameters
+    ----------
+    gidx : HeteroGraphIndex
+        The input graph index.
+    op : str
+        The binary op's name, could be ``copy_lhs`` or ``copy_rhs``.
+    list_x : list of tensors
+        The input features (the incoming gradients).
+    list_idx : list of tensors
+        The argmax/argmin index arrays.
+    list_idx_etype : list of tensors
+        The node- or edge-type id arrays.
+    list_dX : list of tensors
+        The gradients to be updated.
+
+    Returns
+    -------
+    tuple of tensors
+        The updated gradients.
+    """
+    use_u = op != 'copy_rhs'
+    use_e = op != 'copy_lhs'
+    list_out = [None] * len(list_dX)
+    for etid in range(gidx.number_of_etypes()):
+        src_id, dst_id = gidx.metagraph.find_edge(etid)  # gidx is reversed
+        x = list_x[src_id]
+        ctx = F.context(x)
+        dtype = F.dtype(x)
+        if use_u:
+            out_shp = (len(list_dX[dst_id]),) + F.shape(x)[1:]
+            list_out[dst_id] = F.zeros(out_shp, dtype, ctx)
+        if use_e:
+            out_shp = (len(list_dX[etid]),) + F.shape(x)[1:]
+            list_out[etid] = F.zeros(out_shp, dtype, ctx)
+    _CAPI_DGLKernelUpdateGradMinMaxHetero(gidx, op,
+                                          [to_dgl_nd(x) for x in list_x],
+                                          [to_dgl_nd(idx) for idx in list_idx],
+                                          [to_dgl_nd(idx_etype) for idx_etype in list_idx_etype],
+                                          [to_dgl_nd_for_write(out) for out in list_out])
+    return tuple(list_out)
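A single-type reference (plain PyTorch; the function and argument names below are hypothetical) of the scatter performed by the kernel behind `_CAPI_DGLKernelUpdateGradMinMaxHetero`: each incoming gradient slot is added to the source row recorded by argmax/argmin, but only when the stored type id matches the tensor currently being updated.

```python
import torch as th

def update_grad_minmax_single(dZ, idx, idx_type, this_type, num_src_rows):
    # dZ, idx, idx_type: (num_dst_rows, dim); idx holds winning source rows,
    # idx_type holds the node/edge type of each winner.
    grad = th.zeros(num_src_rows, dZ.shape[1], dtype=dZ.dtype)
    mask = idx_type == this_type
    rows = idx[mask]                              # winning source rows
    cols = mask.nonzero(as_tuple=True)[1]         # feature columns
    grad.index_put_((rows, cols), dZ[mask], accumulate=True)
    return grad

dZ = th.ones(2, 2)
idx = th.tensor([[0, 2], [2, 2]])
idx_type = th.tensor([[0, 0], [1, 0]])
print(update_grad_minmax_single(dZ, idx, idx_type, this_type=0, num_src_rows=3))
# rows 0 and 2 of type 0 accumulate the matching gradients; the type-1 entry is skipped
```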
def _bwd_segment_cmp(feat, arg, m):
    r""" Backward phase of segment reduction (for 'min'/'max' reduction).
......
@@ -50,6 +50,19 @@ void ScatterAdd(NDArray feat,
  });
}

+/*! \brief Update gradients for reduce operator max/min on heterogeneous graph.*/
+template <int XPU, typename IdType, int bits>
+void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
+                             const std::string& op,
+                             const std::vector<NDArray>& feat,
+                             const std::vector<NDArray>& idx,
+                             const std::vector<NDArray>& idx_etype,
+                             std::vector<NDArray>* out) {
+  SWITCH_BITS(bits, DType, {
+    cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
+  });
+}
+
/*! \brief Backward function of segment cmp.*/
template <int XPU, typename IdType, int bits>
void BackwardSegmentCmp(
@@ -121,6 +134,32 @@ template void ScatterAdd<kDLCPU, int64_t, 64>(
    NDArray feat,
    NDArray arg,
    NDArray out);
+template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 16>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 16>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 32>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 32>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 64>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 64>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 16>(
    NDArray feat,
    NDArray arg,
......
@@ -8,6 +8,9 @@
#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>
+#include <dgl/base_heterograph.h>
+#include <vector>
+#include <string>

namespace dgl {
namespace aten {
@@ -101,6 +104,86 @@ void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  }
}

+/*!
+ * \brief Update gradients for the max/min reducer on a heterogeneous graph.
+ * \param graph The input heterogeneous graph.
+ * \param op The binary operator, could be `copy_lhs` or `copy_rhs`.
+ * \param list_feat List of the input tensors (the incoming gradients).
+ * \param list_idx List of the argmax/argmin index tensors.
+ * \param list_idx_ntypes List of the node- or edge-type tensors.
+ * \param list_out List of the output tensors.
+ */
+template <typename IdType, typename DType>
+void UpdateGradMinMax_hetero(HeteroGraphPtr graph,
+                             const std::string& op,
+                             const std::vector<NDArray>& list_feat,
+                             const std::vector<NDArray>& list_idx,
+                             const std::vector<NDArray>& list_idx_ntypes,
+                             std::vector<NDArray>* list_out) {
+  if (op == "copy_lhs") {
+    std::vector<std::vector<dgl_id_t>> dst_src_ntids(graph->NumVertexTypes(),
+                                                     std::vector<dgl_id_t>());
+    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
+      auto pair = graph->meta_graph()->FindEdge(etype);
+      const dgl_id_t dst_id = pair.first;  // graph is reversed
+      const dgl_id_t src_id = pair.second;
+      dst_src_ntids[dst_id].push_back(src_id);  // can have duplicates. Use a hash table to optimize.
+    }
+    std::vector<bool> updated(graph->NumVertexTypes());
+    for (int dst_id = 0; dst_id < dst_src_ntids.size(); ++dst_id) {
+      std::fill(updated.begin(), updated.end(), false);
+      for (int j = 0; j < dst_src_ntids[dst_id].size(); ++j) {
+        int src_id = dst_src_ntids[dst_id][j];
+        if (updated[src_id]) continue;
+        const DType* feat_data = list_feat[dst_id].Ptr<DType>();
+        const IdType* idx_data = list_idx[dst_id].Ptr<IdType>();
+        const IdType* idx_ntype_data = list_idx_ntypes[dst_id].Ptr<IdType>();
+        DType* out_data = (*list_out)[src_id].Ptr<DType>();
+        int dim = 1;
+        for (int i = 1; i < (*list_out)[src_id]->ndim; ++i)
+          dim *= (*list_out)[src_id]->shape[i];
+        int n = list_feat[dst_id]->shape[0];
+#pragma omp parallel for
+        for (int i = 0; i < n; ++i) {
+          for (int k = 0; k < dim; ++k) {
+            if (src_id == idx_ntype_data[i * dim + k]) {
+              const int write_row = idx_data[i * dim + k];
+#pragma omp atomic
+              out_data[write_row * dim + k] += feat_data[i * dim + k];  // feat = dZ
+            }
+          }
+        }
+        updated[src_id] = true;
+      }
+    }
+  } else if (op == "copy_rhs") {
+    for (dgl_type_t etid = 0; etid < graph->NumEdgeTypes(); ++etid) {
+      auto pair = graph->meta_graph()->FindEdge(etid);
+      const dgl_id_t dst_id = pair.first;  // graph is reversed
+      const dgl_id_t src_id = pair.second;
+      const DType* feat_data = list_feat[dst_id].Ptr<DType>();
+      const IdType* idx_data = list_idx[dst_id].Ptr<IdType>();
+      const IdType* idx_ntype_data = list_idx_ntypes[dst_id].Ptr<IdType>();
+      DType* out_data = (*list_out)[etid].Ptr<DType>();
+      int dim = 1;
+      for (int i = 1; i < (*list_out)[etid]->ndim; ++i)
+        dim *= (*list_out)[etid]->shape[i];
+      int n = list_feat[dst_id]->shape[0];
+#pragma omp parallel for
+      for (int i = 0; i < n; ++i) {
+        for (int k = 0; k < dim; ++k) {
+          if (etid == idx_ntype_data[i * dim + k]) {
+            const int write_row = idx_data[i * dim + k];
+#pragma omp atomic
+            out_data[write_row * dim + k] += feat_data[i * dim + k];  // feat = dZ
+          }
+        }
+      }
+    }
+  } else {
+    LOG(FATAL) << "Unsupported binary operator: " << op;
+  }
+}
+
/*!
 * \brief CPU kernel of backward phase of segment min/max.
 * \note math equation: out[arg[i, k], k] = feat[i, k]
......
@@ -54,8 +54,8 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat,
-    std::vector<NDArray> vec_out,
-    const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* vec_out,
+    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
@@ -69,7 +69,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
        CSRMatrix csr = vec_csr[etype];
        NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
        NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
-        NDArray out = vec_out[dst_id];
+        NDArray out = (*vec_out)[dst_id];
        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
      }
    });
@@ -77,23 +77,44 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_BITS(bits, DType, {
      SWITCH_OP(op, Op, {
+        std::vector<bool> updated((*vec_out).size(), false);
+        // TODO(Israt): use vector updated to fill(out...) too
+        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
+          DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
+          if (reduce == "max")
+            std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Max<DType>::zero);
+          else
+            std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Min<DType>::zero);
+          const dgl_type_t dst_id = out_node_tids[etype];
+          if (!updated[dst_id]) {
+            updated[dst_id] = true;
+            if (Op::use_lhs) {
+              IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
+              std::fill(argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
+            }
+            if (Op::use_rhs) {
+              IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
+              std::fill(arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
+            }
+          }
+        }
        /* Call SpMM for each relation type */
        for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
          const dgl_type_t src_id = ufeat_node_tids[etype];
          const dgl_type_t dst_id = out_node_tids[etype];
          CSRMatrix csr = vec_csr[etype];
-          DType *out_off = vec_out[out_node_tids[etype]].Ptr<DType>();
+          DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
          NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
          NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
-          NDArray out = vec_out[dst_id];
+          NDArray out = (*vec_out)[dst_id];
          if (reduce == "max") {
-            std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
-            cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
-                bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
+            cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
+                bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
+                (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
          } else {
-            std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
-            cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
-                bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
+            cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
+                bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
+                (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
          }
        }
      });
@@ -132,42 +153,42 @@ template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
......
@@ -309,6 +309,93 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
#endif  // _WIN32
}

+/*!
+ * \brief CPU kernel of SpMM-Min/Max on CSR format for heterogeneous graphs.
+ * \param bcast Broadcast information.
+ * \param csr The CSR matrix.
+ * \param ufeat The feature on source nodes.
+ * \param efeat The feature on edges.
+ * \param out The result feature on destination nodes.
+ * \param argu Arg-Min/Max on source nodes: the source node indices that produced the
+ *        minimum/maximum values of the reduction result on destination nodes. It is
+ *        useful in computing gradients of the Min/Max reducer.
+ * \param arge Arg-Min/Max on edges: the edge indices that produced the minimum/maximum
+ *        values of the reduction result on destination nodes. It is useful in computing
+ *        gradients of the Min/Max reducer.
+ * \param argu_ntype Node type of the winning source nodes recorded in argu.
+ * \param arge_etype Edge type of the winning edges recorded in arge.
+ * \param ntype Node type id of the source nodes of the current relation.
+ * \param etype Edge type id of the current relation.
+ * \note It uses the node parallel strategy: different threads are responsible for the
+ *       computation of different nodes.
+ * \note The result will contain infinity for zero-degree nodes.
+ */
+template <typename IdType, typename DType, typename Op, typename Cmp>
+void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
+                      NDArray efeat, NDArray out, NDArray argu, NDArray arge,
+                      NDArray argu_ntype, NDArray arge_etype,
+                      const int ntype, const int etype) {
+  const bool has_idx = !IsNullArray(csr.data);
+  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
+  const IdType* indices = static_cast<IdType*>(csr.indices->data);
+  const IdType* edges =
+      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
+  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
+  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
+  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
+                rhs_dim = bcast.rhs_len;
+  DType* O = static_cast<DType*>(out->data);
+  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
+  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
+  IdType* argX_ntype = Op::use_lhs ? static_cast<IdType*>(argu_ntype->data) : nullptr;
+  IdType* argW_etype = Op::use_rhs ? static_cast<IdType*>(arge_etype->data) : nullptr;
+  CHECK_NOTNULL(indptr);
+  CHECK_NOTNULL(O);
+  if (Op::use_lhs) {
+    CHECK_NOTNULL(indices);
+    CHECK_NOTNULL(X);
+    CHECK_NOTNULL(argX);
+  }
+  if (Op::use_rhs) {
+    if (has_idx)
+      CHECK_NOTNULL(edges);
+    CHECK_NOTNULL(W);
+    CHECK_NOTNULL(argW);
+  }
+  // TODO(Israt): Use LIBXSMM. The homogeneous-graph path uses LIBXSMM when enabled.
+  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
+    for (auto rid = b; rid < e; ++rid) {
+      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
+      DType* out_off = O + rid * dim;
+      IdType* argx_off = argX + rid * dim;
+      IdType* argw_off = argW + rid * dim;
+      IdType* argx_ntype = argX_ntype + rid * dim;
+      IdType* argw_etype = argW_etype + rid * dim;
+      for (IdType j = row_start; j < row_end; ++j) {
+        const IdType cid = indices[j];
+        const IdType eid = has_idx ? edges[j] : j;
+        for (int64_t k = 0; k < dim; ++k) {
+          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
+          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
+          const DType* lhs_off =
+              Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
+          const DType* rhs_off =
+              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
+          const DType val = Op::Call(lhs_off, rhs_off);
+          if (Cmp::Call(out_off[k], val)) {
+            out_off[k] = val;
+            if (Op::use_lhs) {
+              argx_off[k] = cid;
+              argx_ntype[k] = ntype;
+            }
+            if (Op::use_rhs) {
+              argw_off[k] = eid;
+              argw_etype[k] = etype;
+            }
+          }
+        }
+      }
+    }
+  });
+}
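For readers more comfortable in Python, a compact per-relation reference of the loop above (plain PyTorch, copy_u messages only; names are hypothetical). The caller is assumed to pre-fill `out` with -inf and the arg/type buffers with 0/-1, mirroring the std::fill calls in the CPU SpMMCsrHetero dispatcher, so that running this once per relation on shared buffers reproduces the cross-type reduction.

```python
import torch as th

def spmm_max_csr_hetero_ref(indptr, indices, ufeat, src_ntype_id, out, argu, argu_ntype):
    # One relation: elementwise max of incoming source rows into `out`,
    # recording the winning source row id and its node type id.
    for rid in range(len(indptr) - 1):
        for j in range(indptr[rid], indptr[rid + 1]):
            cid = indices[j]
            better = ufeat[cid] > out[rid]
            out[rid][better] = ufeat[cid][better]
            argu[rid][better] = cid
            argu_ntype[rid][better] = src_ntype_id

out = th.full((2, 3), float('-inf'))
argu = th.zeros((2, 3), dtype=th.long)
argu_ntype = th.full((2, 3), -1, dtype=th.long)
spmm_max_csr_hetero_ref([0, 2, 3], [0, 1, 1], th.randn(2, 3), 0, out, argu, argu_ntype)
```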
/*!
 * \brief CPU kernel of SpMM-Min/Max on Coo format.
 * \param bcast Broadcast information.
......
@@ -4,10 +4,12 @@
 * \brief Segment reduce C APIs and definitions.
 */
#include <dgl/array.h>
+#include <dgl/base_heterograph.h>
#include "./segment_reduce.cuh"
#include "./functor.cuh"
#include "./utils.h"

namespace dgl {
using namespace cuda;
@@ -48,6 +50,19 @@ void ScatterAdd(NDArray feat,
}

+template <int XPU, typename IdType, int bits>
+void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
+                             const std::string& op,
+                             const std::vector<NDArray>& feat,
+                             const std::vector<NDArray>& idx,
+                             const std::vector<NDArray>& idx_etype,
+                             std::vector<NDArray>* out) {
+  SWITCH_BITS(bits, DType, {
+    LOG(FATAL) << "Not implemented. Please use CPU version.";
+  });
+}
+
template <int XPU, typename IdType, int bits>
void BackwardSegmentCmp(NDArray feat,
                        NDArray arg,
@@ -118,6 +133,32 @@ template void ScatterAdd<kDLGPU, int64_t, 64>(
    NDArray feat,
    NDArray idx,
    NDArray out);
+template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 16>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 16>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 32>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 32>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 64>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
+template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 64>(
+    const HeteroGraphPtr& g, const std::string& op,
+    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
+    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDLGPU, int32_t, 16>(
    NDArray feat,
    NDArray arg,
......
@@ -515,15 +515,15 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat,
-    std::vector<NDArray> vec_out,
-    const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* vec_out,
+    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
    const std::vector<dgl_type_t>& out_ntids) {  // output node type id
  bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
  bool use_efeat = op != "copy_lhs";
  auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
  SWITCH_BITS(bits, DType, {
-    std::vector<DType*> trans_out(vec_out.size(), NULL);
+    std::vector<DType*> trans_out((*vec_out).size(), NULL);

    bool use_legacy_cusparsemm =
        (CUDART_VERSION < 11000) &&
@@ -532,9 +532,9 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
        (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false)));
    // Create temporary output buffer to store non-transposed output
    if (use_legacy_cusparsemm) {
-      for (dgl_type_t ntype = 0; ntype < vec_out.size(); ++ntype) {
-        const int m = vec_out[ntype]->shape[0];
-        const int n = vec_out[ntype]->shape[1];
+      for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
+        const int m = (*vec_out)[ntype]->shape[0];
+        const int n = (*vec_out)[ntype]->shape[1];
        if (m == 0) continue;
        DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx,
            m * n * sizeof(DType)));
@@ -577,7 +577,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
      if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
        /* If CUDA is less than 11.0, put the output in trans_out for later transposition */
        DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
-            static_cast<DType*>(vec_out[dst_id]->data);
+            static_cast<DType*>((*vec_out)[dst_id]->data);
        cusparse::CusparseCsrmm2Hetero<DType, IdType>(
            csr.indptr->ctx, csr,
            static_cast<DType*>(vec_ufeat[src_id]->data),
@@ -593,8 +593,8 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
            csr.indptr->ctx, csr,
            static_cast<DType*>(vec_ufeat[src_id]->data),
            static_cast<DType*>(efeat->data),
-            // TODO(Israt): Change vec_out to trans_out to support CUDA version < 11
-            static_cast<DType*>(vec_out[dst_id]->data),
+            // TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
+            static_cast<DType*>((*vec_out)[dst_id]->data),
            x_length, thr_entry->stream);
      } else {  // general kernel
        NDArray ufeat = (vec_ufeat.size() == 0) ?
@@ -603,27 +603,10 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
            NullArray() : vec_efeat[etype];
        SWITCH_OP(op, Op, {
          cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
-              bcast, csr, ufeat, efeat, vec_out[dst_id], NullArray(), NullArray());
+              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
        });
      }
-    } else if (reduce == "max") {
-      SWITCH_OP(op, Op, {
-        NDArray ufeat = (vec_ufeat.size() == 0) ?
-            NullArray() : vec_ufeat[src_id];
-        NDArray efeat = (vec_efeat.size() == 0) ?
-            NullArray() : vec_efeat[etype];
-        cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
-            bcast, csr, ufeat, efeat, vec_out[dst_id], out_aux[0], out_aux[1]);
-      });
-    } else if (reduce == "min") {
-      SWITCH_OP(op, Op, {
-        NDArray ufeat = (vec_ufeat.size() == 0) ?
-            NullArray() : vec_ufeat[src_id];
-        NDArray efeat = (vec_efeat.size() == 0) ?
-            NullArray() : vec_efeat[etype];
-        cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
-            bcast, csr, ufeat, efeat, vec_out[dst_id], out_aux[0], out_aux[1]);
-      });
+      // TODO(Israt): Add support for max/min reducer
    } else {
      LOG(FATAL) << "Not implemented";
    }
@@ -631,11 +614,11 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
    if (use_legacy_cusparsemm) {
      // transpose output
-      for (dgl_type_t ntype = 0; ntype < vec_out.size(); ++ntype) {
-        const int m = vec_out[ntype]->shape[0];
-        const int n = vec_out[ntype]->shape[1];
+      for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
+        const int m = (*vec_out)[ntype]->shape[0];
+        const int n = (*vec_out)[ntype]->shape[1];
        if (m == 0) continue;
-        DType *C_data = static_cast<DType*>(vec_out[ntype]->data);
+        DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data);
        _Transpose(trans_out[ntype], C_data, n, m);
        device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
      }
@@ -709,37 +692,37 @@ template void SpMMCsrHetero<kDLGPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 32>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 64>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
    const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
-    std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
+    std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCoo<kDLGPU, int32_t, 16>(
......
...@@ -55,30 +55,35 @@ void SpMM(const std::string& op, const std::string& reduce, ...@@ -55,30 +55,35 @@ void SpMM(const std::string& op, const std::string& reduce,
/*! \brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph support. */ /*! \brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph support. */
void SpMMHetero(const std::string& op, const std::string& reduce, void SpMMHetero(const std::string& op, const std::string& reduce,
HeteroGraphPtr graph, HeteroGraphPtr graph,
std::vector<NDArray> ufeat_vec, const std::vector<NDArray>& ufeat_vec,
std::vector<NDArray> efeat_vec, const std::vector<NDArray>& efeat_vec,
std::vector<NDArray> out, std::vector<NDArray>* out,
std::vector<NDArray> out_aux) { std::vector<std::vector<NDArray>>* out_aux) {
SparseFormat format = graph->SelectFormat(0, CSC_CODE); SparseFormat format = graph->SelectFormat(0, CSC_CODE);
std::vector<CSRMatrix> vec_graph; std::vector<CSRMatrix> vec_graph;
std::vector<dgl_type_t> ufeat_eid; std::vector<dgl_type_t> ufeat_eid;
std::vector<dgl_type_t> efeat_eid; std::vector<dgl_type_t> efeat_eid;
std::vector<dgl_type_t> out_eid; std::vector<dgl_type_t> out_eid;
auto pair = graph->meta_graph()->FindEdge(0); // first etype
NDArray ufeat_etype0 = (ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[pair.first];
NDArray efeat_etype0 = (efeat_vec.size() == 0) ? NullArray() : efeat_vec[0];
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_graph.push_back(graph->GetCSCMatrix(etype)); vec_graph.push_back(graph->GetCSCMatrix(etype));
auto pair = graph->meta_graph()->FindEdge(etype); auto pair = graph->meta_graph()->FindEdge(etype);
ufeat_eid.push_back(pair.first); ufeat_eid.push_back(pair.first);
efeat_eid.push_back(etype); efeat_eid.push_back(etype);
out_eid.push_back(pair.second); out_eid.push_back(pair.second);
if (ufeat_etype0->shape[1] != ufeat_vec[pair.first]->shape[1])
LOG(FATAL) << "Column width of the input node features of all etypes must be same.";
if (efeat_etype0->shape[1] != efeat_vec[etype]->shape[1])
LOG(FATAL) << "Column width of the input edge features of all etypes must be same.";
} }
NDArray efeat = (efeat_vec.size() == 0) ? NullArray() : efeat_vec[efeat_eid[0]]; const auto& bcast = CalcBcastOff(op, ufeat_etype0, efeat_etype0);
NDArray ufeat = (ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[ufeat_eid[0]];
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", { ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out[out_eid[0]]->dtype, bits, "Feature data", { ATEN_FLOAT_BITS_SWITCH((*out)[out_eid[0]]->dtype, bits, "Feature data", {
if (format == SparseFormat::kCSC) { if (format == SparseFormat::kCSC) {
SpMMCsrHetero<XPU, IdType, bits>( SpMMCsrHetero<XPU, IdType, bits>(
op, reduce, bcast, vec_graph, op, reduce, bcast, vec_graph,
@@ -226,6 +231,24 @@ void ScatterAddDispatch(NDArray feat, NDArray idx, NDArray out) {
  });
}
/*! \brief Update gradients (reduce op max/min) dispatch function on heterogeneous graph. */
void UpdateGradMinMaxDispatchHetero(const HeteroGraphPtr& graph,
                                    const std::string& op,
                                    const std::vector<NDArray>& feat,
                                    const std::vector<NDArray>& idx,
                                    const std::vector<NDArray>& idx_etype,
                                    std::vector<NDArray>* out) {
  auto pair = graph->meta_graph()->FindEdge(0);  // checking the first etype
  auto src_id = pair.first;
  ATEN_XPU_SWITCH_CUDA(feat[src_id]->ctx.device_type, XPU, "ScatterAdd", {
    ATEN_ID_TYPE_SWITCH(idx[src_id]->dtype, IdType, {
      ATEN_FLOAT_BITS_SWITCH(feat[src_id]->dtype, bits, "Feature data", {
        UpdateGradMinMax_hetero<XPU, IdType, bits>(graph, op, feat, idx, idx_etype, out);
      });
    });
  });
}
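For orientation, the backward rule this dispatch serves is the standard argmax/argmin one: the forward max/min remembers which incoming message won each output slot, and the backward pass scatters the output gradient back to exactly those winners, much like torch's scatter_add over the recorded indices. A dense, toy-sized sketch of that idea, with made-up shapes that do not mirror the kernel's actual layout:

import torch as th

# Toy forward: per-destination max over incoming messages, recording argmax.
msgs = th.randn(5, 4)                 # 5 messages, feature width 4
dst = th.tensor([0, 0, 1, 2, 2])      # destination node of each message
out = th.full((3, 4), float('-inf'))  # 3 destination nodes
arg = th.zeros(3, 4, dtype=th.long)   # winning message index per (node, column)
for m in range(msgs.shape[0]):
    better = msgs[m] > out[dst[m]]
    out[dst[m]][better] = msgs[m][better]
    arg[dst[m]][better] = m

# Toy backward: route each output gradient entry to the message that won it.
grad_out = th.randn(3, 4)
grad_msgs = th.zeros_like(msgs)
cols = th.arange(4)
for i in range(out.shape[0]):
    grad_msgs[arg[i], cols] += grad_out[i]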
/*! \brief Backward segment cmp dispatch function. */
void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {
  ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "BackwardSegmentCmp", {
@@ -333,35 +356,33 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMMHetero")
    List<Value> list_U = args[3];
    List<Value> list_E = args[4];
    List<Value> list_V = args[5];
    List<Value> list_ArgU = args[6];
    List<Value> list_ArgE = args[7];
    List<Value> list_ArgU_ntype = args[8];
    List<Value> list_ArgE_etype = args[9];
    std::vector<std::vector<NDArray>> Arg_vec;  // ArgU + ArgE + ArgU_ntype + ArgE_etype
    for (int i = 0; i < 4; ++i) {
      Arg_vec.push_back(std::vector<NDArray>());
    }
    std::vector<NDArray> U_vec = ListValueToVector<NDArray>(list_U);
    std::vector<NDArray> V_vec = ListValueToVector<NDArray>(list_V);
    std::vector<NDArray> E_vec = ListValueToVector<NDArray>(list_E);
    Arg_vec[0] = ListValueToVector<NDArray>(list_ArgU);
    Arg_vec[1] = ListValueToVector<NDArray>(list_ArgE);
    Arg_vec[2] = ListValueToVector<NDArray>(list_ArgU_ntype);
    Arg_vec[3] = ListValueToVector<NDArray>(list_ArgE_etype);
    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
      auto pair = graph->meta_graph()->FindEdge(etype);
      const dgl_id_t src_id = pair.first;
      const dgl_id_t dst_id = pair.second;
      NDArray U = (U_vec.size() == 0) ? NullArray() : U_vec[src_id];
      NDArray E = (E_vec.size() == 0) ? NullArray() : E_vec[etype];
      CheckCtx(graph->Context(), {U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
               {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
      CheckContiguous({U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
                      {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
    }
    SpMMHetero(op, reduce_op, graph.sptr(), U_vec, E_vec, &V_vec, &Arg_vec);
  });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
@@ -440,6 +461,23 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelScatterAdd")
    ScatterAddDispatch(feat, idx, out);
  });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelUpdateGradMinMaxHetero")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
List<Value> list_feat = args[2];
List<Value> list_idx = args[3];
List<Value> list_idx_etype = args[4];
List<Value> list_out = args[5];
std::vector<NDArray> vec_feat = ListValueToVector<NDArray>(list_feat);
std::vector<NDArray> vec_idx = ListValueToVector<NDArray>(list_idx);
std::vector<NDArray> vec_idx_etype = ListValueToVector<NDArray>(list_idx_etype);
std::vector<NDArray> vec_out = ListValueToVector<NDArray>(list_out);
// CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
// CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
UpdateGradMinMaxDispatchHetero(graph.sptr(), op, vec_feat, vec_idx, vec_idx_etype, &vec_out);
});
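A hedged sketch of how a Python-side wrapper might invoke the registration above, following the zerocopy/_init_api conventions used by the other _CAPI kernels in python/dgl/sparse.py; the function name and argument layout here are assumptions for illustration, not necessarily the commit's actual _update_grad_minmax_hetero wrapper:

from dgl import backend as F
from dgl._ffi.function import _init_api

def update_grad_minmax_hetero_sketch(gidx, op, feats, idxs, idx_etypes, outs):
    # feats/idxs/idx_etypes/outs are per-type lists of backend tensors;
    # the outputs are passed for write so the C++ kernel can fill them in place.
    _CAPI_DGLKernelUpdateGradMinMaxHetero(
        gidx, op,
        [F.zerocopy_to_dgl_ndarray(x) for x in feats],
        [F.zerocopy_to_dgl_ndarray(x) for x in idxs],
        [F.zerocopy_to_dgl_ndarray(x) for x in idx_etypes],
        [F.zerocopy_to_dgl_ndarray_for_write(x) for x in outs])
    return outs

_init_api("dgl.sparse")  # already done at the bottom of dgl/sparse.py; exposes the _CAPI symbols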
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelBwdSegmentCmp") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelBwdSegmentCmp")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
NDArray feat = args[0]; NDArray feat = args[0];
......
@@ -39,8 +39,8 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
                   const std::vector<CSRMatrix>& csr,
                   const std::vector<NDArray>& ufeat,
                   const std::vector<NDArray>& efeat,
                   std::vector<NDArray>* out,
                   std::vector<std::vector<NDArray>>* out_aux,
                   const std::vector<dgl_type_t>& ufeat_eid,
                   const std::vector<dgl_type_t>& out_eid);
/*!
@@ -130,6 +130,17 @@ void ScatterAdd(NDArray feat,
                NDArray idx,
                NDArray out);
/*!
 * \brief Update gradients for the max and min reduce operators along the first dimension.
 */
template <int XPU, typename IdType, int bits>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
                             const std::string& op,
                             const std::vector<NDArray>& feat,
                             const std::vector<NDArray>& idx,
                             const std::vector<NDArray>& idx_etype,
                             std::vector<NDArray>* out);
/*!
 * \brief Backward function of segment cmp.
 */
...
@@ -12,12 +12,12 @@ from dgl import DGLError
import test_utils
from test_utils import parametrize_dtype, get_cases
from scipy.sparse import rand

rfuncs = {'sum': fn.sum, 'max': fn.max, 'min': fn.min, 'mean': fn.mean}
fill_value = {'sum': 0, 'max': float("-inf")}
feat_size = 2

@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@unittest.skipIf(F._default_context_str == 'gpu', reason="Max/min reducer not supported on GPU yet.")
def create_test_heterograph(idtype):
    # test heterograph from the docstring, plus a user -- wishes -- game relation
@@ -38,15 +38,44 @@ def create_test_heterograph(idtype):
    assert g.device == F.ctx()
    return g
def create_test_heterograph_2(idtype):
    src = np.random.randint(0, 5, 25)
    dst = np.random.randint(0, 5, 25)
    g = dgl.heterograph({
        ('user', 'becomes', 'player'): (src, dst),
        ('user', 'follows', 'user'): (src, dst),
        ('user', 'plays', 'game'): (src, dst),
        ('user', 'wishes', 'game'): (src, dst),
        ('developer', 'develops', 'game'): (src, dst),
    }, idtype=idtype, device=F.ctx())
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g

def create_test_heterograph_large(idtype):
    src = np.random.randint(0, 50, 2500)
    dst = np.random.randint(0, 50, 2500)
    g = dgl.heterograph({
        ('user', 'follows', 'user'): (src, dst),
        ('user', 'plays', 'game'): (src, dst),
        ('user', 'wishes', 'game'): (src, dst),
        ('developer', 'develops', 'game'): (src, dst),
    }, idtype=idtype, device=F.ctx())
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g
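Before the tests themselves, a compact end-to-end illustration of what the new kernels enable: a copy_u message with a max reduce across several relation types of a heterograph on CPU. The toy graph and feature names are made up for the example:

import dgl
import dgl.function as fn
import torch as th

g = dgl.heterograph({
    ('user', 'plays', 'game'): ([0, 1, 1], [0, 0, 1]),
    ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
})
g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), 2)
g.nodes['developer'].data['h'] = th.randn(g.num_nodes('developer'), 2)

# Per-etype copy_u + max, then a max cross-type reduce: the combination this
# change adds CPU kernel support for.
g.multi_update_all(
    {etype: (fn.copy_u('h', 'm'), fn.max('m', 'y')) for etype in g.canonical_etypes},
    'max')
print(g.nodes['game'].data['y'])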
@parametrize_dtype
def test_unary_copy_u(idtype):
    def _test(mfunc, rfunc):
        g = create_test_heterograph_2(idtype)
        g0 = create_test_heterograph(idtype)
        g1 = create_test_heterograph_large(idtype)
        cross_reducer = rfunc.__name__
        x1 = F.randn((g.num_nodes('user'), feat_size))
        x2 = F.randn((g.num_nodes('developer'), feat_size))
        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes['user'].data['h'] = x1
@@ -58,56 +87,63 @@ def test_unary_copy_u(idtype):
        with F.record_grad():
            g.multi_update_all(
                {etype : (mfunc('h', 'm'), rfunc('m', 'y'))
                 for etype in g.canonical_etypes},
                cross_reducer)
            r1 = g.nodes['game'].data['y'].clone()
            r2 = g.nodes['user'].data['y'].clone()
            r3 = g.nodes['player'].data['y'].clone()
            loss = r1.sum() + r2.sum() + r3.sum()
            F.backward(loss)
            n_grad1 = F.grad(g.nodes['user'].data['h']).clone()
            n_grad2 = F.grad(g.nodes['developer'].data['h']).clone()
        g.nodes['user'].data.clear()
        g.nodes['developer'].data.clear()
        g.nodes['game'].data.clear()
        g.nodes['player'].data.clear()
        #################################################################
        # update_all(): call msg_passing for all etypes
        #################################################################
        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes['user'].data['h'] = x1
        g.nodes['developer'].data['h'] = x2
        with F.record_grad():
            g.update_all(mfunc('h', 'm'), rfunc('m', 'y'))
            r4 = g.nodes['game'].data['y']
            r5 = g.nodes['user'].data['y']
            r6 = g.nodes['player'].data['y']
            loss = r4.sum() + r5.sum() + r6.sum()
            F.backward(loss)
            n_grad3 = F.grad(g.nodes['user'].data['h'])
            n_grad4 = F.grad(g.nodes['developer'].data['h'])
        # correctness check
        assert F.allclose(r1, r4)
        assert F.allclose(r2, r5)
        assert F.allclose(r3, r6)
        assert(F.allclose(n_grad1, n_grad3))
        assert(F.allclose(n_grad2, n_grad4))
    _test(fn.copy_u, fn.sum)
    _test(fn.copy_u, fn.max)
    _test(fn.copy_u, fn.min)
    # _test('copy_u', 'mean')
@parametrize_dtype
def test_unary_copy_e(idtype):
    def _test(mfunc, rfunc):
        g = create_test_heterograph_large(idtype)
        g0 = create_test_heterograph_2(idtype)
        g1 = create_test_heterograph(idtype)
        cross_reducer = rfunc.__name__
        x1 = F.randn((g.num_edges('plays'), feat_size))
        x2 = F.randn((g.num_edges('follows'), feat_size))
        x3 = F.randn((g.num_edges('develops'), feat_size))
        x4 = F.randn((g.num_edges('wishes'), feat_size))
        F.attach_grad(x1)
        F.attach_grad(x2)
        F.attach_grad(x3)
@@ -127,44 +163,60 @@ def test_unary_copy_e(idtype):
                 'follows': (mfunc('eid', 'm'), rfunc('m', 'y')),
                 'develops': (mfunc('eid', 'm'), rfunc('m', 'y')),
                 'wishes': (mfunc('eid', 'm'), rfunc('m', 'y'))},
                cross_reducer)
            r1 = g.nodes['game'].data['y'].clone()
            r2 = g.nodes['user'].data['y'].clone()
            loss = r1.sum() + r2.sum()
            F.backward(loss)
            e_grad1 = F.grad(g['develops'].edata['eid']).clone()
            e_grad2 = F.grad(g['plays'].edata['eid']).clone()
            e_grad3 = F.grad(g['wishes'].edata['eid']).clone()
            e_grad4 = F.grad(g['follows'].edata['eid']).clone()
        for _, etype, _ in g.canonical_etypes:
            g[etype].edata.clear()
        #################################################################
        # update_all(): call msg_passing for all etypes
        #################################################################
        # TODO(Israt): output type can be None in multi_update and empty
        # tensor in new_update_all
        F.attach_grad(x1)
        F.attach_grad(x2)
        F.attach_grad(x3)
        F.attach_grad(x4)
        g['plays'].edata['eid'] = x1
        g['follows'].edata['eid'] = x2
        g['develops'].edata['eid'] = x3
        g['wishes'].edata['eid'] = x4
        with F.record_grad():
            g.update_all(mfunc('eid', 'm'), rfunc('m', 'y'))
            r3 = g.nodes['game'].data['y']
            r4 = g.nodes['user'].data['y']
            loss = r3.sum() + r4.sum()
            F.backward(loss)
            e_grad5 = F.grad(g['develops'].edata['eid'])
            e_grad6 = F.grad(g['plays'].edata['eid'])
            e_grad7 = F.grad(g['wishes'].edata['eid'])
            e_grad8 = F.grad(g['follows'].edata['eid'])
        # correctness check
        def _print_error(a, b):
            for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
                if not np.allclose(x, y):
                    print('@{} {} v.s. {}'.format(i, x, y))
        assert F.allclose(r1, r3)
        assert F.allclose(r2, r4)
        assert(F.allclose(e_grad1, e_grad5))
        assert(F.allclose(e_grad2, e_grad6))
        assert(F.allclose(e_grad3, e_grad7))
        assert(F.allclose(e_grad4, e_grad8))
    _test(fn.copy_e, fn.sum)
    _test(fn.copy_e, fn.max)
    _test(fn.copy_e, fn.min)
    # _test('copy_e', 'mean')
@parametrize_dtype
def test_binary_op(idtype):
    def _test(lhs, rhs, binary_op, reducer):
...