Unverified Commit cb0e1103 authored by Israt Nisa, committed by GitHub

[Feature] Add Min/max reducer in heterogeneous API for unary message functions (#3514)



* min/max support for forward CPU heterograph

* Added etype with each argU values

* scatter_add needs fix

* added scatter_add_hetero. Grads don't match for max reducer

* storing ntype in argX

* fixing scatter_add_hetero

* hetero matches with torch's scatter add

* works copy_e forward+cpu

* added backward for copy_rhs

* Computes gradient for all node types in one kernel

* bug fix

* unittest for max/min on CPU

* renamed scatter_add_hetero to update_grad_minmax_hetero

* lint check and comment out cuda call for max. Code is for CPU only

* lint check

* replace inf with zero

* minor

* lint check

* removed LIBXSMM code from hetero code

* fixing backward operator of UpdateGradMinMaxHetero

* removed backward from update_grad_minmax_hetero

* docstring

* improved docstring and coding style

* Added pass by pointer for output

* typos and pass by references

* Support for copy_rhs

* Added header <string>

* fix bug in copy_u_max

* Added comments and dimension check of all etypes

* skip mxnet check

* pass by pointer output arrays

* updated docstring
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 769718df
import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
from ...heterograph_index import create_unitgraph_from_csr
if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
......@@ -194,10 +194,9 @@ class GSpMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, op, reduce_op, X_len, *feats): # feats = lhs_data + rhs_data
out, (argX, argY) = _gspmm_hetero(gidx, op, reduce_op, X_len, feats)
out, (argX, argY, argX_ntype, argY_etype) = _gspmm_hetero(gidx, op, reduce_op, X_len, feats)
X, Y = feats[:X_len], feats[X_len:]
# TODO (Israt): check target to decide src_id/dst_id?
# checking the first relation to decide for all the relations
src_id, dst_id = gidx.metagraph.find_edge(0)
reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id])
X_shape = tuple([X[i].shape if X[i] is not None else None
......@@ -214,11 +213,11 @@ class GSpMM_hetero(th.autograd.Function):
# checking the first relation to decide for all the relations
if not spmm_cache_argX(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]):
argX = None
argX = tuple([None] * len(X))
if not spmm_cache_argY(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]):
argY = None
argY = tuple([None] * len(X))
ctx.save_for_backward(*feats, argX, argY)
ctx.save_for_backward(*feats, *argX, *argX_ntype, *argY, *argY_etype)
return out
@staticmethod
......@@ -226,14 +225,16 @@ class GSpMM_hetero(th.autograd.Function):
def backward(ctx, *dZ):
gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
ctx.backward_cache = None
feats = ctx.saved_tensors[:-2]
argX = ctx.saved_tensors[-2]
argY = ctx.saved_tensors[-1]
num_ntypes = gidx.number_of_ntypes()
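# saved_tensors layout (see forward): feats, then argX, argX_ntype, argY and
# argY_etype, each block holding one entry per node type.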
feats = ctx.saved_tensors[:-(4 * num_ntypes)]
argX = ctx.saved_tensors[-(4 * num_ntypes):-(3 * num_ntypes)]
argX_ntype = ctx.saved_tensors[-(3 * num_ntypes):-(2 * num_ntypes)]
argY = ctx.saved_tensors[-(2 * num_ntypes):- num_ntypes]
argY_etype = ctx.saved_tensors[-num_ntypes:]
X, Y = feats[:X_len], feats[X_len:]
if op != 'copy_rhs' and any([x is not None for x in X]):
g_rev = gidx.reverse()
# TODO(Israt): implement other combinations of message and reduce functions
if reduce_op == 'sum':
if op == 'mul':
dX = gspmm_hetero(g_rev, 'mul', 'sum', len(X), *tuple(dZ + Y))
......@@ -242,6 +243,17 @@ class GSpMM_hetero(th.autograd.Function):
elif op == 'copy_lhs':
tpl_None = tuple([None] * len(Y))
dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', len(X), *tuple(dZ + tpl_None))
else: # max/min
# Assuming that the features are of the same dimension (enforced by the forward function)
src_id, dst_id = gidx.metagraph.find_edge(0)
dX = tuple([th.zeros((X_shape[i][0],) + dZ[dst_id].shape[1:], dtype=dtype, device=device)
if X[i] is not None else None for i in range(len(X))])
if op == 'mul':
grad = _expand(Y, dZ.shape[1:]).gather(
0, argY.long()) * dZ
dX.scatter_add_(0, argX.long(), grad)
elif op in ['add', 'copy_lhs']:
dX = _update_grad_minmax_hetero(g_rev, op, dZ, argX, argX_ntype, dX)
dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None
for i in range(len(X))])
else: # X has no gradient
......@@ -258,8 +270,18 @@ class GSpMM_hetero(th.autograd.Function):
dY = gsddmm_hetero(gidx, 'mul', X_len, 'u', 'v', *tpl_X_dZ)
elif op in ['add', 'copy_rhs']:
dY = gsddmm_hetero(gidx, 'copy_rhs', X_len, 'u', 'v', *tpl_X_dZ)
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None
for i in range(len(Y))])
else: # max/min
src_id, dst_id = gidx.metagraph.find_edge(0)
dY = tuple([th.zeros((Y_shape[i][0],) + dZ[dst_id].shape[1:], dtype=dtype, device=device)
if Y[i] is not None else None for i in range(len(Y))])
if op == 'mul':
grad = _expand(X, dZ.shape[1:]).gather(
0, argX.long()) * dZ
dY.scatter_add_(0, argY.long(), grad)
elif op in ['add', 'copy_rhs']:
dY = _update_grad_minmax_hetero(gidx.reverse(), op, dZ, argY, argY_etype, dY)
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if dY[i] is not None else None
for i in range(len(dY))])
else: # Y has no gradient
dY = tuple([None] * len(Y))
return (None, None, None, None) + dX + dY
......@@ -273,7 +295,7 @@ def sddmm_cache_X(op, req_grad_X, req_grad_Y):
def sddmm_cache_Y(op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache Y in SDDMM forward stage."""
"""Rules to identify whether to cache Y in SDDMM forward stage."""
if op in ['mul', 'dot'] and req_grad_X:
return True
return False
......@@ -424,6 +446,7 @@ class GSDDMM_hetero(th.autograd.Function):
dY = tuple([None] * len(Y))
return (None, None, None, None, None) + dX + dY
class EdgeSoftmax(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
......
......@@ -4872,10 +4872,6 @@ class DGLHeteroGraph(object):
raise DGLError("User defined functions are not yet "
"supported in update_all for heterogeneous graphs. "
"Please use multi_update_all instead.")
if reduce_func.name in ['max', 'min']:
raise NotImplementedError("Reduce op \'" + reduce_func.name + "\' is not yet "
"supported in update_all for heterogeneous graphs. "
"Please use multi_update_all instead.")
if reduce_func.name in ['mean']:
raise NotImplementedError("Cannot set both intra-type and inter-type reduce "
"operators as 'mean' using update_all. Please use "
......
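With the 'max'/'min' NotImplementedError above removed, `update_all` now accepts max/min reducers on heterogeneous graphs. A minimal sketch of the call pattern this enables (the graph, node types and feature names are illustrative and not taken from the diff):

import dgl
import dgl.function as fn
import torch as th

g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'): ([0, 1], [0, 1]),
})
g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), 2)
# Max reducer across all edge types; before this change update_all raised
# NotImplementedError for 'max'/'min' on heterogeneous graphs.
g.update_all(fn.copy_u('h', 'm'), fn.max('m', 'y'))
print(g.nodes['game'].data['y'])
print(g.nodes['user'].data['y'])  # zero-in-degree users get zeros (inf replaced)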
......@@ -87,7 +87,14 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
ret = gspmm_internal_hetero(g._graph, op,
'sum' if reduce_op == 'mean' else reduce_op,
len(lhs_data), *lhs_and_rhs_tuple)
# TODO (Israt): Add support for 'max', 'min', 'mean' in heterograph
# `update_all` on heterogeneous graphs replaces the inf values with zeros on
# the final output (after processing all etypes). `multi_update_all` performs
# this replacement after processing each etype, so its final output is computed
# from per-etype outputs where inf has already been replaced by zero.
if reduce_op in ['min', 'max']:
ret = tuple([F.replace_inf_with_zero(ret[i]) if ret[i] is not None else None
for i in range(len(ret))])
# TODO (Israt): Add support for 'mean' in heterograph
# divide in degrees for mean reducer.
if reduce_op == 'mean':
ret_shape = F.shape(ret)
......
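A tiny numeric illustration of the comment above (values made up): with a max reducer, destination rows that receive no message keep the identity value -inf, and the final inf-to-zero replacement (shown here with plain torch ops as a stand-in for F.replace_inf_with_zero) turns them into zeros:

import torch as th

per_type_out = th.tensor([[3.0, 1.5],
                          [float('-inf'), float('-inf')]])  # row 1: no in-edges
final = th.where(th.isinf(per_type_out),
                 th.zeros_like(per_type_out), per_type_out)
# final: tensor([[3.0, 1.5],
#                [0.0, 0.0]])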
......@@ -190,7 +190,51 @@ def _gspmm(gidx, op, reduce_op, u, e):
def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
r""" Generalized Sparse Matrix Multiplication interface.
r""" Generalized Sparse Matrix Multiplication interface on heterogeneous graphs.
It handles multiple node and edge types of the graph. For each edge type, it applies
:attr:`op` to the source node feature and the edge feature to form a message on the
edge, then aggregates the messages by :attr:`reduce_op` on the destination nodes of
that edge type.
.. math::
x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,
:math:`x_e` refer to :attr:`u`, :attr:`e` respectively. :math:`\rho` means the binary
operator :attr:`op` and :math:`\psi` means the reduce operator :attr:`reduce_op`;
:math:`\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.
Note that this function does not handle gradients.
Parameters
----------
gidx : HeteroGraphIndex
The input graph index.
op : str
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``copy_lhs``,
``copy_rhs``.
reduce_op : str
Reduce operator, could be ``sum``, ``max``, ``min``.
u_len : int
The number of tensors in ``u`` (source node features).
u_and_e_tuple : Tuple of tensors
Tuple of source nodes' features and edges' features. ``u_and_e_tuple[:u_len]``
stores the source nodes' features for all source node types. ``u_and_e_tuple[u_len:]``
stores the edges' features for all edge types.
The source nodes' features of a source node type may be None if op is ``copy_rhs``.
The edges' features of an edge type may be None if op is ``copy_lhs``.
Returns
-------
tuple
The returned tuple is composed of two elements:
- The first element refers to the tuple of result tensors.
- The second element is a tuple of (arg_u, arg_e, arg_u_ntype, arg_e_etype),
which is useful when the reducer is ``min``/``max``.
Notes
-----
This function does not handle gradients.
"""
u_tuple, e_tuple = u_and_e_tuple[:u_len], u_and_e_tuple[u_len:]
use_u = op != 'copy_rhs'
......@@ -199,11 +243,25 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
# deal with scalar features.
expand_u, expand_e = False, False
list_u = [None] * gidx.number_of_ntypes()
list_v = [None] * gidx.number_of_ntypes()
list_e = [None] * gidx.number_of_etypes()
num_ntypes = gidx.number_of_ntypes()
num_etypes = gidx.number_of_etypes()
list_u = [None] * num_ntypes
list_v = [None] * num_ntypes
list_e = [None] * num_etypes
list_arg_u_nd = [None] * num_ntypes
list_arg_u = [None] * num_ntypes
list_arg_u_ntype_nd = [None] * num_ntypes
list_arg_u_ntype = [None] * num_ntypes
# TODO(Israt): double check ntype or etype
list_arg_e_nd = [None] * num_ntypes
list_arg_e = [None] * num_ntypes
list_arg_e_etype_nd = [None] * num_ntypes
list_arg_e_etype = [None] * num_ntypes
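# For min/max reducers, the four arg lists record, per destination node type,
# which source node / edge produced the winning value and its node/edge type id.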
for etid in range(gidx.number_of_etypes()):
use_cmp = reduce_op in ['max', 'min']
idtype = getattr(F, gidx.dtype)
for etid in range(num_etypes):
src_id, dst_id = gidx.metagraph.find_edge(etid)
u = u_tuple[src_id] if use_u else None
e = e_tuple[etid] if use_e else None
......@@ -224,29 +282,42 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
v_shp = (gidx.number_of_nodes(dst_id), ) +\
infer_broadcast_shape(op, u_shp[1:], e_shp[1:])
list_v[dst_id] = F.zeros(v_shp, dtype, ctx)
if use_cmp:
if use_u:
list_arg_u[dst_id] = F.zeros(v_shp, idtype, ctx)
list_arg_u_ntype[dst_id] = F.zeros(v_shp, idtype, ctx)
if use_e:
list_arg_e[dst_id] = F.zeros(v_shp, idtype, ctx)
list_arg_e_etype[dst_id] = F.zeros(v_shp, idtype, ctx)
list_arg_u_nd[dst_id] = to_dgl_nd_for_write(list_arg_u[dst_id])
list_arg_u_ntype_nd[dst_id] = to_dgl_nd_for_write(list_arg_u_ntype[dst_id])
list_arg_e_nd[dst_id] = to_dgl_nd_for_write(list_arg_e[dst_id])
list_arg_e_etype_nd[dst_id] = to_dgl_nd_for_write(list_arg_e_etype[dst_id])
use_cmp = reduce_op in ['max', 'min']
arg_u, arg_e = None, None
idtype = getattr(F, gidx.dtype)
if use_cmp:
if use_u:
arg_u = F.zeros(v_shp, idtype, ctx)
if use_e:
arg_e = F.zeros(v_shp, idtype, ctx)
arg_u_nd = to_dgl_nd_for_write(arg_u)
arg_e_nd = to_dgl_nd_for_write(arg_e)
if gidx.number_of_edges(0) > 0:
_CAPI_DGLKernelSpMMHetero(gidx, op, reduce_op,
[to_dgl_nd(u_i) for u_i in list_u],
[to_dgl_nd(e_i) for e_i in list_e],
[to_dgl_nd_for_write(v_i) for v_i in list_v],
arg_u_nd,
arg_e_nd)
arg_u = None if arg_u is None else F.zerocopy_from_dgl_ndarray(arg_u_nd)
arg_e = None if arg_e is None else F.zerocopy_from_dgl_ndarray(arg_e_nd)
list_arg_u_nd, list_arg_e_nd,
list_arg_u_ntype_nd, list_arg_e_etype_nd)
for l, arg_u_nd in enumerate(list_arg_u_nd):
# TODO(Israt): l or src_id as index of lhs
list_arg_u[l] = None if list_arg_u[l] is None else F.zerocopy_from_dgl_ndarray(arg_u_nd)
if expand_u and use_cmp:
list_arg_u[l] = F.squeeze(list_arg_u[l], -1)
for l, arg_e_nd in enumerate(list_arg_e_nd):
list_arg_e[l] = None if list_arg_e[l] is None else F.zerocopy_from_dgl_ndarray(arg_e_nd)
if expand_e and use_cmp:
list_arg_e[l] = F.squeeze(list_arg_e[l], -1)
for l, arg_u_ntype_nd in enumerate(list_arg_u_ntype_nd):
list_arg_u_ntype[l] = None if arg_u_ntype_nd is None \
else F.zerocopy_from_dgl_ndarray(arg_u_ntype_nd)
for l, arg_e_etype_nd in enumerate(list_arg_e_etype_nd):
list_arg_e_etype[l] = None if arg_e_etype_nd is None \
else F.zerocopy_from_dgl_ndarray(arg_e_etype_nd)
# To deal with scalar node/edge features.
for l in range(gidx.number_of_ntypes()):
for l in range(num_ntypes):
# Replace None with an empty tensor; the forward function doesn't accept None in the tuple.
v = list_v[l]
v = F.tensor([]) if v is None else v
......@@ -254,12 +325,7 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
v = F.squeeze(v, -1) # To deal with scalar node/edge features.
list_v[l] = v
out = tuple(list_v)
if expand_u and use_cmp:
arg_u = F.squeeze(arg_u, -1)
if expand_e and use_cmp:
arg_e = F.squeeze(arg_e, -1)
return out, (arg_u, arg_e)
return out, (list_arg_u, list_arg_e, list_arg_u_ntype, list_arg_e_etype)
def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
......@@ -481,6 +547,50 @@ def _scatter_add(x, idx, m):
return out
def _update_grad_minmax_hetero(gidx, op, list_x, list_idx, list_idx_etype, list_dX):
r""" Update gradients for reduce operator max and min (on first dimension) implementation.
Parameters
----------
gidx : HeteroGraphIndex
The input graph index.
op : str
The message op's name, could be ``copy_lhs`` or ``copy_rhs``.
list_x : List of tensors
List of the input features.
list_idx : List of tensors
List of the argmin/argmax index arrays.
list_idx_etype : List of tensors
List of the node- or edge-type arrays.
list_dX : List of tensors
List of gradients.
Returns
-------
Tuple of tensors
The output gradient tensors, one per node or edge type.
"""
use_u = op != 'copy_rhs'
use_e = op != 'copy_lhs'
list_out = [None] * len(list_dX)
for etid in range(gidx.number_of_etypes()):
src_id, dst_id = gidx.metagraph.find_edge(etid) # gidx is reversed
x = list_x[src_id]
ctx = F.context(x)
dtype = F.dtype(x)
if use_u:
out_shp = (len(list_dX[dst_id]),) + F.shape(x)[1:]
list_out[dst_id] = F.zeros(out_shp, dtype, ctx)
if use_e:
out_shp = (len(list_dX[etid]),) + F.shape(x)[1:]
list_out[etid] = F.zeros(out_shp, dtype, ctx)
_CAPI_DGLKernelUpdateGradMinMaxHetero(gidx, op,
[to_dgl_nd(x) for x in list_x],
[to_dgl_nd(idx) for idx in list_idx],
[to_dgl_nd(idx_etype) for idx_etype in list_idx_etype],
[to_dgl_nd_for_write(out) for out in list_out])
return tuple(list_out)
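For reference, a minimal dense sketch of the per-type accumulation performed by the kernel; this is an assumption-laden illustration using plain 2-D torch tensors, not the actual _CAPI call above:

import torch as th

def update_grad_minmax_ref(dZ, arg, arg_type, dX, type_id):
    """Accumulate dZ[i, k] into dX[arg[i, k], k] wherever the argmin/argmax
    entry at (i, k) belongs to `type_id`; entries of other types (or the -1
    fill value) contribute nothing."""
    mask = (arg_type == type_id)
    safe_idx = th.where(mask, arg, th.zeros_like(arg)).long()
    dX.scatter_add_(0, safe_idx, dZ * mask)  # masked slots add 0 to row 0
    return dX

# toy usage: two output rows, gradients routed back to three source rows
dZ = th.tensor([[1.0, 2.0], [3.0, 4.0]])
arg = th.tensor([[0, 2], [1, 1]])
arg_type = th.tensor([[0, 0], [0, 1]])   # last entry came from another type
dX = th.zeros(3, 2)
print(update_grad_minmax_ref(dZ, arg, arg_type, dX, type_id=0))
# -> tensor([[1., 0.], [3., 0.], [0., 2.]])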
def _bwd_segment_cmp(feat, arg, m):
r""" Backward phase of segment reduction (for 'min'/'max' reduction).
......
......@@ -50,6 +50,19 @@ void ScatterAdd(NDArray feat,
});
}
/*! \brief Update gradients for reduce operator max/min on heterogeneous graph.*/
template <int XPU, typename IdType, int bits>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op,
const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) {
SWITCH_BITS(bits, DType, {
cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
});
}
/*! \brief Backward function of segment cmp.*/
template <int XPU, typename IdType, int bits>
void BackwardSegmentCmp(
......@@ -121,6 +134,32 @@ template void ScatterAdd<kDLCPU, int64_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 16>(
NDArray feat,
NDArray arg,
......
......@@ -8,6 +8,9 @@
#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/base_heterograph.h>
#include <vector>
#include <string>
namespace dgl {
namespace aten {
......@@ -101,6 +104,86 @@ void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
}
}
/*!
* \param graph The input heterogeneous graph.
* \param op The binary operator, could be `copy_lhs`, `copy_rhs`.
* \param list_feat List of the input tensors.
* \param list_idx List of the index tensors.
* \param list_idx_ntypes List of the node- or edge-type tensors.
* \param list_out List of the output tensors.
*/
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(HeteroGraphPtr graph,
const std::string& op,
const std::vector<NDArray>& list_feat,
const std::vector<NDArray>& list_idx,
const std::vector<NDArray>& list_idx_ntypes,
std::vector<NDArray>* list_out) {
if (op == "copy_lhs") {
std::vector<std::vector<dgl_id_t>> dst_src_ntids(graph->NumVertexTypes(),
std::vector<dgl_id_t>());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_id_t dst_id = pair.first; // graph is reversed
const dgl_id_t src_id = pair.second;
dst_src_ntids[dst_id].push_back(src_id); // may contain duplicates; a hash table could be used to optimize
}
std::vector<bool> updated(graph->NumVertexTypes());
for (int dst_id = 0; dst_id < dst_src_ntids.size(); ++dst_id) {
std::fill(updated.begin(), updated.end(), false);
for (int j = 0; j < dst_src_ntids[dst_id].size(); ++j) {
int src_id = dst_src_ntids[dst_id][j];
if (updated[src_id]) continue;
const DType* feat_data = list_feat[dst_id].Ptr<DType>();
const IdType* idx_data = list_idx[dst_id].Ptr<IdType>();
const IdType* idx_ntype_data = list_idx_ntypes[dst_id].Ptr<IdType>();
DType* out_data = (*list_out)[src_id].Ptr<DType>();
int dim = 1;
for (int i = 1; i < (*list_out)[src_id]->ndim; ++i)
dim *= (*list_out)[src_id]->shape[i];
int n = list_feat[dst_id]->shape[0];
#pragma omp parallel for
for (int i = 0; i < n; ++i) {
for (int k = 0; k < dim; ++k) {
if (src_id == idx_ntype_data[i * dim + k]) {
const int write_row = idx_data[i * dim + k];
#pragma omp atomic
out_data[write_row * dim + k] += feat_data[i * dim + k]; // feat = dZ
}
}
}
updated[src_id] = true;
}
}
} else if (op == "copy_rhs") {
for (dgl_type_t etid = 0; etid < graph->NumEdgeTypes(); ++etid) {
auto pair = graph->meta_graph()->FindEdge(etid);
const dgl_id_t dst_id = pair.first; // graph is reversed
const dgl_id_t src_id = pair.second;
const DType* feat_data = list_feat[dst_id].Ptr<DType>();
const IdType* idx_data = list_idx[dst_id].Ptr<IdType>();
const IdType* idx_ntype_data = list_idx_ntypes[dst_id].Ptr<IdType>();
DType* out_data = (*list_out)[etid].Ptr<DType>();
int dim = 1;
for (int i = 1; i < (*list_out)[etid]->ndim; ++i)
dim *= (*list_out)[etid]->shape[i];
int n = list_feat[dst_id]->shape[0];
#pragma omp parallel for
for (int i = 0; i < n; ++i) {
for (int k = 0; k < dim; ++k) {
if (etid == idx_ntype_data[i * dim + k]) {
const int write_row = idx_data[i * dim + k];
#pragma omp atomic
out_data[write_row * dim + k] += feat_data[i * dim + k]; // feat = dZ
}
}
}
}
} else {
LOG(FATAL) << "Unsupported binary operator: " << op;
}
}
/*!
* \brief CPU kernel of backward phase of segment min/max.
* \note math equation: out[arg[i, k], k] = feat[i, k]
......
......@@ -54,8 +54,8 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_ufeat,
const std::vector<NDArray>& vec_efeat,
std::vector<NDArray> vec_out,
const std::vector<NDArray>& out_aux,
std::vector<NDArray>* vec_out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids) {
const int64_t dim = bcast.out_len;
......@@ -69,7 +69,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
CSRMatrix csr = vec_csr[etype];
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = vec_out[dst_id];
NDArray out = (*vec_out)[dst_id];
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
}
});
......@@ -77,23 +77,44 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
} else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
std::vector<bool> updated((*vec_out).size(), false);
// TODO(Israt): use vector updated to fill(out...) too
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
if (reduce == "max")
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Max<DType>::zero);
else
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Min<DType>::zero);
const dgl_type_t dst_id = out_node_tids[etype];
if (!updated[dst_id]) {
updated[dst_id] = true;
if (Op::use_lhs) {
IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
std::fill(argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
}
if (Op::use_rhs) {
IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
std::fill(arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
}
}
}
/* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype];
DType *out_off = vec_out[out_node_tids[etype]].Ptr<DType>();
DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = vec_out[dst_id];
NDArray out = (*vec_out)[dst_id];
if (reduce == "max") {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
(*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
} else {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
(*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
}
}
});
......@@ -132,42 +153,42 @@ template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
......
......@@ -309,6 +309,93 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
#endif // _WIN32
}
/*!
* \brief CPU kernel of SpMM-Min/Max on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \param argu Arg-Min/Max on source nodes, which refers to the source node indices
*        corresponding to the minimum/maximum values of the reduction result on
*        destination nodes. It is useful in computing gradients of the Min/Max reducer.
* \param arge Arg-Min/Max on edges, which refers to the edge indices corresponding
*        to the minimum/maximum values of the reduction result on destination nodes.
*        It is useful in computing gradients of the Min/Max reducer.
* \param argu_ntype Node type ids of the argmin/argmax source nodes.
* \param arge_etype Edge type ids of the argmin/argmax edges.
* \param ntype Id of the source node type handled by this call.
* \param etype Id of the edge type handled by this call.
* \note It uses node parallel strategy: different threads are responsible for the
*       computation of different nodes.
* \note The result will contain infinity for zero-degree nodes.
*/
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
NDArray efeat, NDArray out, NDArray argu, NDArray arge,
NDArray argu_ntype, NDArray arge_etype,
const int ntype, const int etype) {
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const IdType* indices = static_cast<IdType*>(csr.indices->data);
const IdType* edges =
has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;
const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len;
DType* O = static_cast<DType*>(out->data);
IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
IdType* argX_ntype = Op::use_lhs ? static_cast<IdType*>(argu_ntype->data) : nullptr;
IdType* argW_etype = Op::use_rhs ? static_cast<IdType*>(arge_etype->data) : nullptr;
CHECK_NOTNULL(indptr);
CHECK_NOTNULL(O);
if (Op::use_lhs) {
CHECK_NOTNULL(indices);
CHECK_NOTNULL(X);
CHECK_NOTNULL(argX);
}
if (Op::use_rhs) {
if (has_idx)
CHECK_NOTNULL(edges);
CHECK_NOTNULL(W);
CHECK_NOTNULL(argW);
}
// TODO(Israt): Use LIBXSMM. The homogeneous-graph kernel uses LIBXSMM when enabled.
runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
for (auto rid = b; rid < e; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
IdType* argx_off = argX + rid * dim;
IdType* argw_off = argW + rid * dim;
IdType* argx_ntype = argX_ntype + rid * dim;
IdType* argw_etype = argW_etype + rid * dim;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off =
Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off =
Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
const DType val = Op::Call(lhs_off, rhs_off);
if (Cmp::Call(out_off[k], val)) {
out_off[k] = val;
if (Op::use_lhs) {
argx_off[k] = cid;
argx_ntype[k] = ntype;
}
if (Op::use_rhs) {
argw_off[k] = eid;
argw_etype[k] = etype;
}
}
}
}
}
});
}
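The loop above, restated as a plain-Python reference (assumptions: copy_lhs, no broadcasting, CSR given as NumPy indptr/indices arrays); this is only a sketch of the kernel's semantics, not the implementation:

import numpy as np

def spmm_cmp_csr_ref(indptr, indices, ufeat, out, argu, argu_ntype, ntype,
                     better=lambda cur, val: val > cur):
    """For each destination row, keep the running max (or min, via `better`)
    of the source features and record which source node (and its type) won."""
    for rid in range(len(indptr) - 1):
        for j in range(indptr[rid], indptr[rid + 1]):
            cid = indices[j]
            for k in range(ufeat.shape[1]):
                val = ufeat[cid, k]
                if better(out[rid, k], val):
                    out[rid, k] = val
                    argu[rid, k] = cid
                    argu_ntype[rid, k] = ntype

# toy usage: 2 destination rows, 2 source rows, single feature column
indptr = np.array([0, 2, 3]); indices = np.array([0, 1, 1])
ufeat = np.array([[1.0], [5.0]])
out = np.full((2, 1), -np.inf)
argu = np.zeros((2, 1), dtype=int); argu_ntype = np.zeros((2, 1), dtype=int)
spmm_cmp_csr_ref(indptr, indices, ufeat, out, argu, argu_ntype, ntype=0)
# out -> [[5.], [5.]]; argu -> [[1], [1]]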
/*!
* \brief CPU kernel of SpMM-Min/Max on Coo format.
* \param bcast Broadcast information.
......
......@@ -4,10 +4,12 @@
* \brief Segment reduce C APIs and definitions.
*/
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include "./segment_reduce.cuh"
#include "./functor.cuh"
#include "./utils.h"
namespace dgl {
using namespace cuda;
......@@ -48,6 +50,19 @@ void ScatterAdd(NDArray feat,
}
template <int XPU, typename IdType, int bits>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op,
const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) {
SWITCH_BITS(bits, DType, {
LOG(FATAL) << "Not implemented. Please use CPU version.";
});
}
template <int XPU, typename IdType, int bits>
void BackwardSegmentCmp(NDArray feat,
NDArray arg,
......@@ -118,6 +133,32 @@ template void ScatterAdd<kDLGPU, int64_t, 64>(
NDArray feat,
NDArray idx,
NDArray out);
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDLGPU, int32_t, 16>(
NDArray feat,
NDArray arg,
......
......@@ -515,15 +515,15 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_ufeat,
const std::vector<NDArray>& vec_efeat,
std::vector<NDArray> vec_out,
const std::vector<NDArray>& out_aux,
std::vector<NDArray>* vec_out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, // ufeat node type id
const std::vector<dgl_type_t>& out_ntids) { // output node type id
bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
bool use_efeat = op != "copy_lhs";
auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
SWITCH_BITS(bits, DType, {
std::vector<DType*> trans_out(vec_out.size(), NULL);
std::vector<DType*> trans_out((*vec_out).size(), NULL);
bool use_legacy_cusparsemm =
(CUDART_VERSION < 11000) &&
......@@ -532,9 +532,9 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
(op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false)));
// Create temporary output buffer to store non-transposed output
if (use_legacy_cusparsemm) {
for (dgl_type_t ntype = 0; ntype < vec_out.size(); ++ntype) {
const int m = vec_out[ntype]->shape[0];
const int n = vec_out[ntype]->shape[1];
for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue;
DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx,
m * n * sizeof(DType)));
......@@ -577,7 +577,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) { // cusparse
/* If CUDA is less than 11.0, put the output in trans_out for later transposition */
DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
static_cast<DType*>(vec_out[dst_id]->data);
static_cast<DType*>((*vec_out)[dst_id]->data);
cusparse::CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
......@@ -593,8 +593,8 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(efeat->data),
// TODO(Israt): Change vec_out to trans_out to support CUDA version < 11
static_cast<DType*>(vec_out[dst_id]->data),
// TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
static_cast<DType*>((*vec_out)[dst_id]->data),
x_length, thr_entry->stream);
} else { // general kernel
NDArray ufeat = (vec_ufeat.size() == 0) ?
......@@ -603,27 +603,10 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
NullArray() : vec_efeat[etype];
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id], NullArray(), NullArray());
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
});
}
} else if (reduce == "max") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id], out_aux[0], out_aux[1]);
});
} else if (reduce == "min") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id], out_aux[0], out_aux[1]);
});
// TODO(Israt): Add support for max/min reducer
} else {
LOG(FATAL) << "Not implemented";
}
......@@ -631,11 +614,11 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
if (use_legacy_cusparsemm) {
// transpose output
for (dgl_type_t ntype = 0; ntype < vec_out.size(); ++ntype) {
const int m = vec_out[ntype]->shape[0];
const int n = vec_out[ntype]->shape[1];
for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue;
DType *C_data = static_cast<DType*>(vec_out[ntype]->data);
DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data);
_Transpose(trans_out[ntype], C_data, n, m);
device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
}
......@@ -709,37 +692,37 @@ template void SpMMCsrHetero<kDLGPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCoo<kDLGPU, int32_t, 16>(
......
......@@ -55,30 +55,35 @@ void SpMM(const std::string& op, const std::string& reduce,
/*! \brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph support. */
void SpMMHetero(const std::string& op, const std::string& reduce,
HeteroGraphPtr graph,
std::vector<NDArray> ufeat_vec,
std::vector<NDArray> efeat_vec,
std::vector<NDArray> out,
std::vector<NDArray> out_aux) {
const std::vector<NDArray>& ufeat_vec,
const std::vector<NDArray>& efeat_vec,
std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux) {
SparseFormat format = graph->SelectFormat(0, CSC_CODE);
std::vector<CSRMatrix> vec_graph;
std::vector<dgl_type_t> ufeat_eid;
std::vector<dgl_type_t> efeat_eid;
std::vector<dgl_type_t> out_eid;
auto pair = graph->meta_graph()->FindEdge(0); // first etype
NDArray ufeat_etype0 = (ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[pair.first];
NDArray efeat_etype0 = (efeat_vec.size() == 0) ? NullArray() : efeat_vec[0];
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_graph.push_back(graph->GetCSCMatrix(etype));
auto pair = graph->meta_graph()->FindEdge(etype);
ufeat_eid.push_back(pair.first);
efeat_eid.push_back(etype);
out_eid.push_back(pair.second);
if (ufeat_etype0->shape[1] != ufeat_vec[pair.first]->shape[1])
LOG(FATAL) << "Column width of the input node features must be the same for all etypes.";
if (efeat_etype0->shape[1] != efeat_vec[etype]->shape[1])
LOG(FATAL) << "Column width of the input edge features must be the same for all etypes.";
}
NDArray efeat = (efeat_vec.size() == 0) ? NullArray() : efeat_vec[efeat_eid[0]];
NDArray ufeat = (ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[ufeat_eid[0]];
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
const auto& bcast = CalcBcastOff(op, ufeat_etype0, efeat_etype0);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out[out_eid[0]]->dtype, bits, "Feature data", {
ATEN_FLOAT_BITS_SWITCH((*out)[out_eid[0]]->dtype, bits, "Feature data", {
if (format == SparseFormat::kCSC) {
SpMMCsrHetero<XPU, IdType, bits>(
op, reduce, bcast, vec_graph,
......@@ -226,6 +231,24 @@ void ScatterAddDispatch(NDArray feat, NDArray idx, NDArray out) {
});
}
/*! \brief Update gradients (reduce op max/min) dispatch function on heterogeneous graph. */
void UpdateGradMinMaxDispatchHetero(const HeteroGraphPtr& graph,
const std::string& op,
const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) {
auto pair = graph->meta_graph()->FindEdge(0); // checking the first etype
auto src_id = pair.first;
ATEN_XPU_SWITCH_CUDA(feat[src_id]->ctx.device_type, XPU, "ScatterAdd", {
ATEN_ID_TYPE_SWITCH(idx[src_id]->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(feat[src_id]->dtype, bits, "Feature data", {
UpdateGradMinMax_hetero<XPU, IdType, bits>(graph, op, feat, idx, idx_etype, out);
});
});
});
}
/*! \brief Backward segment cmp dispatch function.*/
void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "BackwardSegmentCmp", {
......@@ -333,35 +356,33 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMMHetero")
List<Value> list_U = args[3];
List<Value> list_E = args[4];
List<Value> list_V = args[5];
NDArray ArgU = args[6];
NDArray ArgE = args[7];
std::vector<NDArray> U_vec;
std::vector<NDArray> V_vec;
std::vector<NDArray> E_vec;
U_vec.reserve(list_U.size());
V_vec.reserve(list_V.size());
E_vec.reserve(list_E.size());
for (Value val : list_U) {
U_vec.push_back(val->data);
}
for (Value val : list_V) {
V_vec.push_back(val->data);
}
for (Value val : list_E) {
E_vec.push_back(val->data);
List<Value> list_ArgU = args[6];
List<Value> list_ArgE = args[7];
List<Value> list_ArgU_ntype = args[8];
List<Value> list_ArgE_etype = args[9];
std::vector<std::vector<NDArray>> Arg_vec; // ArgU + ArgE
for (int i = 0; i < 4; ++i) { // ArgU + ArgE + ArgU_ntype + ArgE_etype
Arg_vec.push_back(std::vector<NDArray>());
}
std::vector<NDArray> U_vec = ListValueToVector<NDArray>(list_U);
std::vector<NDArray> V_vec = ListValueToVector<NDArray>(list_V);
std::vector<NDArray> E_vec = ListValueToVector<NDArray>(list_E);
Arg_vec[0] = ListValueToVector<NDArray>(list_ArgU);
Arg_vec[1] = ListValueToVector<NDArray>(list_ArgE);
Arg_vec[2] = ListValueToVector<NDArray>(list_ArgU_ntype);
Arg_vec[3] = ListValueToVector<NDArray>(list_ArgE_etype);
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_id_t src_id = pair.first;
const dgl_id_t dst_id = pair.second;
NDArray U = (U_vec.size() == 0) ? NullArray() : U_vec[src_id];
NDArray E = (E_vec.size() == 0) ? NullArray() : E_vec[etype];
CheckCtx(graph->Context(), {U, E, V_vec[dst_id], ArgU, ArgE},
CheckCtx(graph->Context(), {U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
CheckContiguous({U, E, V_vec[dst_id], ArgU, ArgE},
CheckContiguous({U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
}
SpMMHetero(op, reduce_op, graph.sptr(), U_vec, E_vec, V_vec, {ArgU, ArgE});
SpMMHetero(op, reduce_op, graph.sptr(), U_vec, E_vec, &V_vec, &Arg_vec);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
......@@ -440,6 +461,23 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelScatterAdd")
ScatterAddDispatch(feat, idx, out);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelUpdateGradMinMaxHetero")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
List<Value> list_feat = args[2];
List<Value> list_idx = args[3];
List<Value> list_idx_etype = args[4];
List<Value> list_out = args[5];
std::vector<NDArray> vec_feat = ListValueToVector<NDArray>(list_feat);
std::vector<NDArray> vec_idx = ListValueToVector<NDArray>(list_idx);
std::vector<NDArray> vec_idx_etype = ListValueToVector<NDArray>(list_idx_etype);
std::vector<NDArray> vec_out = ListValueToVector<NDArray>(list_out);
// CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
// CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
UpdateGradMinMaxDispatchHetero(graph.sptr(), op, vec_feat, vec_idx, vec_idx_etype, &vec_out);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelBwdSegmentCmp")
.set_body([](DGLArgs args, DGLRetValue *rv) {
NDArray feat = args[0];
......
......@@ -39,8 +39,8 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat,
std::vector<NDArray> out,
const std::vector<NDArray>& out_aux,
std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_eid,
const std::vector<dgl_type_t>& out_eid);
/*!
......@@ -130,6 +130,17 @@ void ScatterAdd(NDArray feat,
NDArray idx,
NDArray out);
/*!
* \brief Update gradients for the max and min reduce operators on the first dimension.
*/
template <int XPU, typename IdType, int bits>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op,
const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out);
/*!
* \brief Backward function of segment cmp.
*/
......
......@@ -12,12 +12,12 @@ from dgl import DGLError
import test_utils
from test_utils import parametrize_dtype, get_cases
from scipy.sparse import rand
rfuncs = {'sum': fn.sum, 'max': fn.max, 'min': fn.min, 'mean': fn.mean}
fill_value = {'sum': 0, 'max': float("-inf")}
feat_size = 2
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@unittest.skipIf(F._default_context_str == 'gpu', reason="Max/min reducer not supported on GPU yet.")
def create_test_heterograph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
......@@ -38,15 +38,44 @@ def create_test_heterograph(idtype):
assert g.device == F.ctx()
return g
def create_test_heterograph_2(idtype):
src = np.random.randint(0, 5, 25)
dst = np.random.randint(0, 5, 25)
g = dgl.heterograph({
('user', 'becomes', 'player'): (src, dst),
('user', 'follows', 'user'): (src, dst),
('user', 'plays', 'game'): (src, dst),
('user', 'wishes', 'game'): (src, dst),
('developer', 'develops', 'game'): (src, dst),
}, idtype=idtype, device=F.ctx())
assert g.idtype == idtype
assert g.device == F.ctx()
return g
def create_test_heterograph_large(idtype):
src = np.random.randint(0, 50, 2500)
dst = np.random.randint(0, 50, 2500)
g = dgl.heterograph({
('user', 'follows', 'user'): (src, dst),
('user', 'plays', 'game'): (src, dst),
('user', 'wishes', 'game'): (src, dst),
('developer', 'develops', 'game'): (src, dst),
}, idtype=idtype, device=F.ctx())
assert g.idtype == idtype
assert g.device == F.ctx()
return g
@parametrize_dtype
def test_unary_copy_u(idtype):
def _test(mfunc, rfunc):
g = create_test_heterograph(idtype)
g = create_test_heterograph_2(idtype)
g0 = create_test_heterograph(idtype)
g1 = create_test_heterograph_large(idtype)
cross_reducer = rfunc.__name__
x1 = F.randn((g.num_nodes('user'), feat_size))
x2 = F.randn((g.num_nodes('developer'), feat_size))
F.attach_grad(x1)
F.attach_grad(x2)
g.nodes['user'].data['h'] = x1
......@@ -58,56 +87,63 @@ def test_unary_copy_u(idtype):
with F.record_grad():
g.multi_update_all(
{'plays' : (mfunc('h', 'm'), rfunc('m', 'y')),
'follows': (mfunc('h', 'm'), rfunc('m', 'y')),
'develops': (mfunc('h', 'm'), rfunc('m', 'y')),
'wishes': (mfunc('h', 'm'), rfunc('m', 'y'))},
'sum')
r1 = g.nodes['game'].data['y']
F.backward(r1, F.randn(r1.shape))
n_grad1 = F.grad(g.nodes['user'].data['h'])
g.nodes['game'].data.clear()
{etype : (mfunc('h', 'm'), rfunc('m', 'y'))
for etype in g.canonical_etypes},
cross_reducer)
r1 = g.nodes['game'].data['y'].clone()
r2 = g.nodes['user'].data['y'].clone()
r3 = g.nodes['player'].data['y'].clone()
loss = r1.sum() + r2.sum() + r3.sum()
F.backward(loss)
n_grad1 = F.grad(g.nodes['user'].data['h']).clone()
n_grad2 = F.grad(g.nodes['developer'].data['h']).clone()
g.nodes['user'].data.clear()
g.nodes['developer'].data.clear()
g.nodes['game'].data.clear()
g.nodes['player'].data.clear()
#################################################################
# update_all(): call msg_passing for all etypes
#################################################################
g.update_all(mfunc('h', 'm'), rfunc('m', 'y'))
r2 = g.nodes['game'].data['y']
F.backward(r2, F.randn(r2.shape))
n_grad2 = F.grad(g.nodes['user'].data['h'])
# correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y):
print('@{} {} v.s. {}'.format(i, x, y))
if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if not F.allclose(n_grad1, n_grad2):
print('node grad')
_print_error(n_grad1, n_grad2)
assert(F.allclose(n_grad1, n_grad2))
F.attach_grad(x1)
F.attach_grad(x2)
g.nodes['user'].data['h'] = x1
g.nodes['developer'].data['h'] = x2
with F.record_grad():
g.update_all(mfunc('h', 'm'), rfunc('m', 'y'))
r4 = g.nodes['game'].data['y']
r5 = g.nodes['user'].data['y']
r6 = g.nodes['player'].data['y']
loss = r4.sum() + r5.sum() + r6.sum()
F.backward(loss)
n_grad3 = F.grad(g.nodes['user'].data['h'])
n_grad4 = F.grad(g.nodes['developer'].data['h'])
assert F.allclose(r1, r4)
assert F.allclose(r2, r5)
assert F.allclose(r3, r6)
assert(F.allclose(n_grad1, n_grad3))
assert(F.allclose(n_grad2, n_grad4))
_test(fn.copy_u, fn.sum)
# TODO(Israt): Add reduce func to support the following reduce op
# _test('copy_u', 'max')
# _test('copy_u', 'min')
_test(fn.copy_u, fn.max)
_test(fn.copy_u, fn.min)
# _test('copy_u', 'mean')
@parametrize_dtype
def test_unary_copy_e(idtype):
def _test(mfunc, rfunc):
g = create_test_heterograph(idtype)
feat_size = 2
x1 = F.randn((4,feat_size))
x2 = F.randn((4,feat_size))
x3 = F.randn((3,feat_size))
x4 = F.randn((3,feat_size))
g = create_test_heterograph_large(idtype)
g0 = create_test_heterograph_2(idtype)
g1 = create_test_heterograph(idtype)
cross_reducer = rfunc.__name__
x1 = F.randn((g.num_edges('plays'),feat_size))
x2 = F.randn((g.num_edges('follows'),feat_size))
x3 = F.randn((g.num_edges('develops'),feat_size))
x4 = F.randn((g.num_edges('wishes'),feat_size))
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
......@@ -127,44 +163,60 @@ def test_unary_copy_e(idtype):
'follows': (mfunc('eid', 'm'), rfunc('m', 'y')),
'develops': (mfunc('eid', 'm'), rfunc('m', 'y')),
'wishes': (mfunc('eid', 'm'), rfunc('m', 'y'))},
'sum')
r1 = g.nodes['game'].data['y']
F.backward(r1, F.randn(r1.shape))
e_grad1 = F.grad(g['develops'].edata['eid'])
cross_reducer)
r1 = g.nodes['game'].data['y'].clone()
r2 = g.nodes['user'].data['y'].clone()
loss = r1.sum() + r2.sum()
F.backward(loss)
e_grad1 = F.grad(g['develops'].edata['eid']).clone()
e_grad2 = F.grad(g['plays'].edata['eid']).clone()
e_grad3 = F.grad(g['wishes'].edata['eid']).clone()
e_grad4 = F.grad(g['follows'].edata['eid']).clone()
for etype in g.canonical_etypes: g[etype].edata.clear()
#################################################################
# update_all(): call msg_passing for all etypes
#################################################################
# TODO(Israt): output type can be None in multi_update and empty
# tensor in new_update_all
g.update_all(mfunc('eid', 'm'), rfunc('m', 'y'))
r2 = g.nodes['game'].data['y']
F.backward(r2, F.randn(r2.shape))
e_grad2 = F.grad(g['develops'].edata['eid'])
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g['plays'].edata['eid'] = x1
g['follows'].edata['eid'] = x2
g['develops'].edata['eid'] = x3
g['wishes'].edata['eid'] = x4
with F.record_grad():
g.update_all(mfunc('eid', 'm'), rfunc('m', 'y'))
r3 = g.nodes['game'].data['y']
r4 = g.nodes['user'].data['y']
loss = r3.sum() + r4.sum()
F.backward(loss)
e_grad5 = F.grad(g['develops'].edata['eid'])
e_grad6 = F.grad(g['plays'].edata['eid'])
e_grad7 = F.grad(g['wishes'].edata['eid'])
e_grad8 = F.grad(g['follows'].edata['eid'])
# # correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y):
print('@{} {} v.s. {}'.format(i, x, y))
if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if not F.allclose(e_grad1, e_grad2):
print('edge grad')
_print_error(e_grad1, e_grad2)
assert(F.allclose(e_grad1, e_grad2))
assert F.allclose(r1, r3)
assert F.allclose(r2, r4)
assert(F.allclose(e_grad1, e_grad5))
assert(F.allclose(e_grad2, e_grad6))
assert(F.allclose(e_grad3, e_grad7))
assert(F.allclose(e_grad4, e_grad8))
_test(fn.copy_e, fn.sum)
# TODO(Israt): Add reduce func to support the following reduce op
# _test('copy_e', 'max')
# _test('copy_e', 'min')
_test(fn.copy_e, fn.max)
_test(fn.copy_e, fn.min)
# _test('copy_e', 'mean')
@parametrize_dtype
def test_binary_op(idtype):
def _test(lhs, rhs, binary_op, reducer):
......