Unverified commit 071cba1f authored by Zihao Ye, committed by GitHub

[kernel] New SpMM & SDDMM kernel on CPU and CUDA (#1644)

* upd

* simplify

* sddmm dot cpu

* upd

* format

* upd

* compatible with MJ's PR

* lint

* upd

* upd

* upd

* python end

* upd

* upd

* lint

* lint

* upd

* upd

* upd

* upd

* upd

* lint

* fix mxnet

* upd

* lint

* use minjie's ptr

* macro

* upd

* reorg

* lint

* fix corner cases

* upd

* enrich cpu docs

* upd

* upd

* lint

* lint

* pylint

* sx review

* improve docstring

* python doc

* upd

* restructure

* lint

* upd test

* upd

* pylint

* fix corner cases and test
parent 90a103e7
...@@ -25,7 +25,6 @@ struct CSRMatrix;
* \brief Plain COO structure
*
* The data array stores integer ids for reading edge features.
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
......
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/aten/bcast.h
* \brief Broadcast related function C++ header.
*/
#ifndef DGL_BCAST_H_
#define DGL_BCAST_H_
#include <string>
#include <vector>
#include "./runtime/ndarray.h"
using namespace dgl::runtime;
namespace dgl {
/*!
* \brief Broadcast offsets and auxiliary information.
*/
struct BcastOff {
/*!
* \brief offset vector of lhs operand and rhs operand.
* \note lhs_offset[i] indicates the start position of the scalar
* in the lhs operand that is required to compute the i-th element
* in the output, and likewise for rhs_offset.
*
* For example, when lhs array has shape (1, 3) and rhs array
* has shape (5, 1), the resulting array would have shape (5, 3),
* then both lhs_offset and rhs_offset would contain 15 elements.
*
* lhs_offset: 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2
* rhs_offset: 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4
*
* For example, to compute the 7th (row 2, column 0) element in the output,
* we need the 0th element in the lhs array and the 2nd element in the
* rhs array.
*/
std::vector<int64_t> lhs_offset, rhs_offset;
/*! \brief Whether broadcast is required or not. */
bool use_bcast;
/*!
* \brief Auxiliary information for kernel computation
* \note lhs_len refers to the left hand side operand length.
* e.g. 15 for shape (1, 3, 5)
* rhs_len refers to the right hand side operand length.
* e.g. 15 for shape (3, 1, 5)
* out_len refers to the output length.
* e.g. 45 for shape (3, 3, 5)
* reduce_size refers to the reduction size (for ops like dot),
* e.g. 1 for add, and 5 for dot when lhs_shape = rhs_shape = (3, 5).
*/
int64_t lhs_len, rhs_len, out_len, reduce_size;
};
/*!
* \brief Compute broadcast and auxiliary information given the operator
* and operands for kernel computation.
* \param op A string indicating the operator, could be `add`, `sub`,
* `mul`, `div`, `dot`, `copy_u`, `copy_e`.
* \param lhs The left hand side operand of NDArray class.
* \param rhs The right hand side operand of NDArray class.
* \return The broadcast information as a BcastOff object.
*/
BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs);
} // namespace dgl
#endif // DGL_BCAST_H_
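For clarity, the following is a minimal standalone sketch (plain C++, no DGL headers; the loop structure is an illustrative assumption, not the actual CalcBcastOff implementation) of how the offsets documented in BcastOff could be derived for the (1, 3) x (5, 1) example above.
// Illustrative sketch only: derive lhs_offset / rhs_offset for
// lhs shape (1, 3) broadcast against rhs shape (5, 1) -> output shape (5, 3).
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t out_rows = 5, out_cols = 3;
  std::vector<int64_t> lhs_offset, rhs_offset;
  for (int64_t i = 0; i < out_rows; ++i) {
    for (int64_t j = 0; j < out_cols; ++j) {
      // lhs has a single row, so only the column index advances its offset;
      // rhs has a single column, so only the row index advances its offset.
      lhs_offset.push_back(j);   // 0, 1, 2, 0, 1, 2, ...
      rhs_offset.push_back(i);   // 0, 0, 0, 1, 1, 1, ...
    }
  }
  // Element 6 (row 2, column 0) of the output reads lhs element 0 and rhs element 2.
  std::cout << lhs_offset[6] << " " << rhs_offset[6] << std::endl;  // prints "0 2"
  return 0;
}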
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/aten/kernel.h
* \brief Sparse matrix operators.
*/
#ifndef DGL_KERNEL_H_
#define DGL_KERNEL_H_
#include <string>
#include <vector>
#include "array.h"
#include "./bcast.h"
#include "./base_heterograph.h"
namespace dgl {
namespace aten {
/*!
* \brief Generalized Sparse Matrix-Matrix Multiplication.
* \param op The binary operator, could be `add`, `sub`, `mul`, `div`,
* `copy_u`, `copy_e`.
* \param reduce The reduce operator, could be `sum`, `min`, `max`.
* \param graph The graph we apply SpMM on.
* \param ufeat The source node feature.
* \param efeat The edge feature.
* \param out The output feature on destination nodes.
* \param out_aux A list of NDArray's that contains auxiliary information such
* as the argmax on source nodes and edges for reduce operators such as
* `min` and `max`.
* \param format The format of sparse matrix.
*/
void SpMM(const std::string& op, const std::string& reduce,
HeteroGraphPtr graph,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux,
SparseFormat format = SparseFormat::kAny);
/*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication.
* \param op The binary operator, could be `add`, `sub`, `mul`, `div`,
* `dot`, `copy_u`, `copy_e`.
* \param graph The graph we apply SDDMM on.
* \param ufeat The source node feature.
* \param vfeat The destination node feature.
* \param out The output feature on edge.
* \param format The format of sparse matrix.
*/
void SDDMM(const std::string& op,
HeteroGraphPtr graph,
NDArray ufeat,
NDArray vfeat,
NDArray out,
SparseFormat format = SparseFormat::kAny);
} // namespace aten
} // namespace dgl
#endif // DGL_KERNEL_H_
...@@ -30,3 +30,4 @@ from .traversal import *
from .transform import *
from .propagate import *
from .udf import NodeBatch, EdgeBatch
from .sparse import gspmm, gsddmm
...@@ -4,10 +4,10 @@ from . import backend as F
from . import convert
from . import random
__all__ = ['rand_graph', 'rand_bipartite']
def rand_graph(num_nodes, num_edges, restrict_format='any'):
"""Generate a random graph of the given number of nodes/edges.
It uniformly chooses ``num_edges`` from all pairs and forms a graph.
...@@ -34,3 +34,33 @@ def rand_graph(num_nodes, num_edges, restrict_format='any'):
num_nodes=num_nodes, validate=False,
restrict_format=restrict_format)
return g
def rand_bipartite(num_src_nodes, num_dst_nodes, num_edges, restrict_format='any'):
"""Generate a random bipartite graph of the given number of src/dst nodes and
number of edges.
It uniformly chooses ``num_edges`` from all pairs and forms a graph.
Parameters
----------
num_src_nodes : int
The number of source nodes, the :math:`|U|` in :math:`G=(U,V,E)`.
num_dst_nodes : int
The number of destination nodes, the :math:`|V|` in :math:`G=(U,V,E)`.
num_edges : int
The number of edges
restrict_format : 'any', 'coo', 'csr', 'csc', optional
Force the storage format. Default: 'any' (i.e. let DGL decide what to use).
Returns
-------
DGLHeteroGraph
Generated random bipartite graph.
"""
eids = random.choice(num_src_nodes * num_dst_nodes, num_edges, replace=False)
rows = F.astype(eids / num_dst_nodes, F.dtype(eids))
cols = F.astype(eids % num_dst_nodes, F.dtype(eids))
g = convert.bipartite((rows, cols),
num_nodes=(num_src_nodes, num_dst_nodes), validate=False,
restrict_format=restrict_format)
return g
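As a quick sanity check, a hypothetical usage of the helper above could look like the following sketch (it assumes this function is re-exported at the package level as dgl.rand_bipartite).
# Hypothetical usage sketch of rand_bipartite.
import dgl

g = dgl.rand_bipartite(10, 20, 15)   # 10 src nodes, 20 dst nodes, 15 edges
print(g.number_of_src_nodes())       # 10
print(g.number_of_dst_nodes())       # 20
print(g.number_of_edges())           # 15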
"""Module for sparse matrix operators."""
# pylint: disable= invalid-name
import dgl.ndarray as nd
from ._ffi.function import _init_api
from .base import DGLError
from .utils import to_dgl_context
from . import backend as F
def infer_broadcast_shape(op, shp1, shp2):
r"""Check the shape validity, and infer the output shape given input shape and operator.
Note that both :attr:`shp1`, :attr:`shp2` and the returned shape are feature
shapes (i.e. we remove the first dimension, which corresponds to graph statistics
such as number of nodes, number of edges, etc.).
We allow applying op on operands with different shapes, according to the
broadcasting semantics of Numpy/Scipy:
https://numpy.org/doc/stable/user/basics.broadcasting.html
Parameters
----------
op : str
The binary op's name, could be `add`, `sub`, `mul`, `div`, `dot`, `copy_u`, `copy_e`.
shp1 : tuple[int]
The shape of lhs operand.
shp2 : tuple[int]
The shape of rhs operand.
Returns
-------
tuple[int]
shape after broadcasting
"""
pad_shp1, pad_shp2 = shp1, shp2
if op == "dot":
if shp1[-1] != shp2[-1]:
raise DGLError("Dot operator is only available for arrays with the "
"same size on last dimension, but got {} and {}."
.format(shp1, shp2))
if op == "copy_u":
return shp1
if op == "copy_e":
return shp2
# operands are padded to have the same dimensionality with leading 1's.
if len(shp1) > len(shp2):
pad_shp2 = (1,) * (len(shp1) - len(shp2)) + shp2
elif len(shp1) < len(shp2):
pad_shp1 = (1,) * (len(shp2) - len(shp1)) + shp1
for d1, d2 in zip(pad_shp1, pad_shp2):
if d1 != d2 and d1 != 1 and d2 != 1:
raise DGLError("Feature shapes {} and {} are not valid for broadcasting."
.format(shp1, shp2))
rst = tuple(max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2))
return rst[:-1] + (1,) if op == "dot" else rst
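# A few worked cases of the rules above (illustrative only; shapes exclude the
# leading node/edge dimension):
#   infer_broadcast_shape('add', (3, 1), (1, 4))  -> (3, 4)   # standard broadcasting
#   infer_broadcast_shape('dot', (4, 5), (4, 5))  -> (4, 1)   # dot reduces the last dim
#   infer_broadcast_shape('copy_u', (7,), (3,))   -> (7,)     # rhs is ignored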
def to_dgl_nd(x):
"""Convert framework-specific tensor/None to dgl ndarray."""
return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray(x)
# map operator name aliases to the actual names the backend recognizes.
op_mapping = {
'+': 'add',
'-': 'sub',
'*': 'mul',
'/': 'div',
'.': 'dot',
'add': 'add',
'sub': 'sub',
'mul': 'mul',
'div': 'div',
'dot': 'dot',
'copy_u': 'copy_u',
'copy_e': 'copy_e'
}
def gspmm(g, op, reduce_op, u, e):
r""" Generalized Sparse Matrix Multiplication interface. It takes the result of
:attr:`op` on source node feature and edge feature, leads to a message on edge.
Then aggregates the message by :attr:`reduce_op` on destination nodes.
.. math::
x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
where :math:`x_v` is the returned feature on destination nodes, and :math`x_u`,
:math:`x_e` refers to :attr:`u`, :attr:`e` respectively. :math:`\rho` means binary
operator :attr:`op` and :math:`\psi` means reduce operator :attr:`reduce_op`,
:math:`\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.
Note that this function does not handle gradients.
Parameters
----------
g : DGLHeteroGraph
The input graph.
op : str
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``, ``copy_u``,
``copy_e``, or their alias ``+``, ``-``, ``*``, ``/``, ``.``.
reduce_op : str
Reduce operator, could be ``sum``, ``max``, ``min``.
u : tensor or None
The feature on source nodes, could be None if op is ``copy_e``.
e : tensor or None
The feature on edges, could be None if op is ``copy_u``.
Returns
-------
tensor
The result tensor.
Notes
-----
This function does not handle gradients. For scalar input features, we expand
them with an additional dimension of length one (e.g. (90,) becomes (90, 1)
for a graph with 90 nodes/edges).
"""
if u is not None:
if F.ndim(u) == 1:
u = F.unsqueeze(u, -1)
if e is not None:
if F.ndim(e) == 1:
e = F.unsqueeze(e, -1)
op = op_mapping[op]
ctx = F.context(u) if u is not None else F.context(e)
dtype = F.dtype(u) if u is not None else F.dtype(e)
use_u = (op != 'copy_e')
use_e = (op != 'copy_u')
u_shp = F.shape(u) if use_u else (0,)
e_shp = F.shape(e) if use_e else (0,)
v_shp = (g.number_of_dst_nodes(), ) +\
infer_broadcast_shape(op, u_shp[1:], e_shp[1:])
v = F.zeros(v_shp, dtype, ctx)
use_cmp = reduce_op in ['max', 'min']
arg_u = F.zeros(v_shp, g.idtype, ctx) if use_cmp and use_u else None
arg_e = F.zeros(v_shp, g.idtype, ctx) if use_cmp and use_e else None
if g.number_of_edges() > 0:
gidx = g._graph.get_unitgraph(0, to_dgl_context(ctx))
_CAPI_DGLKernelSpMM(gidx, op, reduce_op,
to_dgl_nd(u), to_dgl_nd(e), to_dgl_nd(v),
to_dgl_nd(arg_u), to_dgl_nd(arg_e))
return v, (arg_u, arg_e)
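# Illustrative usage sketch of gspmm (hypothetical toy data; PyTorch backend
# assumed; gspmm is also re-exported as dgl.gspmm by this PR):
#
#   import torch, dgl
#   g = dgl.rand_graph(100, 500)                      # 100 nodes, 500 edges
#   u = torch.rand(100, 16)                           # source node features
#   e = torch.rand(500, 16)                           # edge features
#   v, (arg_u, arg_e) = gspmm(g, 'mul', 'sum', u, e)  # message = u[src] * e[eid]
#   # v has shape (100, 16); arg_u / arg_e are None since 'sum' needs no arg tensors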
def gsddmm(g, op, u, v):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. It
takes the result of :attr:`op` on source node feature and destination node
feature, leads to a feature on edge.
.. math::
x_{e} = \phi(x_u, x_v), \forall (u,e,v)\in \mathcal{G}
where :math:`x_{e}` is the returned feature on edges and :math:`x_u`,
:math:`x_v` refers to :attr:`u`, :attr:`v` respectively. :math:`\phi`
is the binary operator :attr:`op`, and :math:`\mathcal{G}` is the graph
we apply gsddmm on: :attr:`g`.
Parameters
----------
g : DGLHeteroGraph
The input graph.
op : str
Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``, ``copy_u``,
or their alias ``+``, ``-``, ``*``, ``/``, ``.``.
u : tensor or None
The feature on source nodes.
v : tensor or None
The feature on destination nodes, could be None if op is ``copy_u``.
Returns
-------
tensor
The result tensor.
Notes
-----
This function does not handle gradients. For scalar input features, we expand
them with an additional dimension of length one (e.g. (90,) becomes (90, 1)
for a graph with 90 nodes/edges).
"""
if u is not None:
if F.ndim(u) == 1:
u = F.unsqueeze(u, -1)
if v is not None:
if F.ndim(v) == 1:
v = F.unsqueeze(v, -1)
op = op_mapping[op]
ctx = F.context(u)
dtype = F.dtype(u)
u_shp = F.shape(u)
v_shp = F.shape(v) if v is not None else (0,)
e_shp = (g.number_of_edges(), ) +\
infer_broadcast_shape(op, u_shp[1:], v_shp[1:])
e = F.zeros(e_shp, dtype, ctx)
if g.number_of_edges() > 0:
gidx = g._graph.get_unitgraph(0, to_dgl_context(ctx))
_CAPI_DGLKernelSDDMM(gidx, op, to_dgl_nd(u), to_dgl_nd(v), to_dgl_nd(e))
return e
_init_api("dgl.sparse")
...@@ -85,7 +85,9 @@ IdArray HStack(IdArray lhs, IdArray rhs) {
NDArray IndexSelect(NDArray array, IdArray index) {
NDArray ret;
CHECK_SAME_CONTEXT(array, index);
CHECK_GE(array->ndim, 1) << "Only support array with at least 1 dimension";
CHECK_EQ(array->shape[0], array.NumElements()) << "Only support tensor"
<< " whose first dimension equals number of elements, e.g. (5,), (5, 1)";
CHECK_EQ(index->ndim, 1) << "Index array must be an 1D array.";
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
......
/*!
* Copyright (c) 2020 by Contributors
* \file aten/cpu/sddmm.cc
* \brief SDDMM C APIs and definitions.
*/
#include "./sddmm.h"
#include <dgl/array.h>
namespace dgl {
namespace aten {
/*! \brief Generalized SDDMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat,
NDArray vfeat,
NDArray out) {
SWITCH_OP(op, Op, {
cpu::SDDMMCsr<IdType, DType, Op>(bcast, csr, ufeat, vfeat, out);
});
}
template void SDDMMCsr<kDLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void SDDMMCsr<kDLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void SDDMMCsr<kDLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void SDDMMCsr<kDLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
/*! \brief Generalized SDDMM on Coo format. */
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat,
NDArray vfeat,
NDArray out) {
SWITCH_OP(op, Op, {
cpu::SDDMMCoo<IdType, DType, Op>(bcast, coo, ufeat, vfeat, out);
});
}
template void SDDMMCoo<kDLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out);
template void SDDMMCoo<kDLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out);
template void SDDMMCoo<kDLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out);
template void SDDMMCoo<kDLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out);
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/sddmm.h
* \brief SDDMM CPU kernel function header.
*/
#ifndef DGL_ARRAY_CPU_SDDMM_H_
#define DGL_ARRAY_CPU_SDDMM_H_
#include <dgl/array.h>
#include <dgl/bcast.h>
namespace dgl {
namespace aten {
namespace cpu {
/*!
* \brief CPU kernel of g-SDDMM on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param vfeat The feature on destination nodes.
* \param out The result feature on edges.
* \note It uses a node-parallel strategy: different threads are responsible
* for the computation of different nodes.
*/
template <typename IdType, typename DType, typename Op>
void SDDMMCsr(const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat, NDArray vfeat, NDArray out) {
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>();
const IdType* edges = csr.data.Ptr<IdType>();
const DType* X = ufeat.Ptr<DType>();
const DType* Y = vfeat.Ptr<DType>();
const int64_t dim = bcast.out_len,
lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len,
reduce_size = bcast.reduce_size;
DType* O = out.Ptr<DType>();
#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx? edges[j] : j;
DType* out_off = O + eid * dim;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + rid * lhs_dim +\
lhs_add * reduce_size : nullptr;
const DType* rhs_off = Op::use_rhs? Y + cid * rhs_dim +\
rhs_add * reduce_size : nullptr;
out_off[k] = Op::Call(lhs_off, rhs_off, reduce_size);
}
}
}
}
/*!
* \brief CPU kernel of g-SDDMM on Coo format.
* \param bcast Broadcast information.
* \param coo The COO matrix.
* \param ufeat The feature on source nodes.
* \param vfeat The feature on destination nodes.
* \param out The result feature on edges.
* \note It uses an edge-parallel strategy: different threads are responsible
* for the computation of different edges.
*/
template <typename IdType, typename DType, typename Op>
void SDDMMCoo(const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat, NDArray vfeat, NDArray out) {
const bool has_idx = !IsNullArray(coo.data);
const IdType* row = coo.row.Ptr<IdType>();
const IdType* col = coo.col.Ptr<IdType>();
const IdType* edges = coo.data.Ptr<IdType>();
const DType* X = ufeat.Ptr<DType>();
const DType* Y = vfeat.Ptr<DType>();
const int64_t dim = bcast.out_len,
lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len,
reduce_size = bcast.reduce_size;
DType* O = out.Ptr<DType>();
const int64_t nnz = coo.row->shape[0];
#pragma omp parallel for
for (IdType i = 0; i < nnz; ++i) {
const IdType rid = row[i];
const IdType cid = col[i];
const IdType eid = has_idx? edges[i] : i;
DType* out_off = O + eid * dim;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + rid * lhs_dim +\
lhs_add * reduce_size : nullptr;
const DType* rhs_off = Op::use_rhs? Y + cid * rhs_dim +\
rhs_add * reduce_size : nullptr;
out_off[k] = Op::Call(lhs_off, rhs_off, bcast.reduce_size);
}
}
}
namespace op {
//////////////////////////////// binary operators on CPU ////////////////////////////////
template <typename DType>
struct Add {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
return *lhs_off + *rhs_off;
}
};
template <typename DType>
struct Sub {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
return *lhs_off - *rhs_off;
}
};
template <typename DType>
struct Mul {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
return *lhs_off * *rhs_off;
}
};
template <typename DType>
struct Div {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
return *lhs_off / *rhs_off;
}
};
template <typename DType>
struct CopyLhs {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = false;
inline static DType Call(const DType* lhs_off, const DType*, int64_t len = 1) {
return *lhs_off;
}
};
template <typename DType>
struct CopyRhs {
static constexpr bool use_lhs = false;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* , const DType* rhs_off, int64_t len = 1) {
return *rhs_off;
}
};
template <typename DType>
struct Dot {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
DType rst = 0;
for (int64_t l = 0; l < len; ++l) {
rst += lhs_off[l] * rhs_off[l];
}
return rst;
}
};
#define SWITCH_OP(op, Op, ...) \
do { \
if ((op) == "add") { \
typedef dgl::aten::cpu::op::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef dgl::aten::cpu::op::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef dgl::aten::cpu::op::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "div") { \
typedef dgl::aten::cpu::op::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_u") { \
typedef dgl::aten::cpu::op::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_e") { \
typedef dgl::aten::cpu::op::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "dot") { \
typedef dgl::aten::cpu::op::Dot<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Unsupported SDDMM binary operator: " << op; \
} \
} while (0)
} // namespace op
} // namespace cpu
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_CPU_SDDMM_H_
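As a rough mental model of what the CSR kernel above computes for op = dot, the following standalone sketch (plain arrays and hypothetical toy data, no DGL types) walks the same indptr/indices loops and writes one dot product per edge:
// Standalone sketch of the edge-wise computation in SDDMMCsr for op = dot.
// Toy graph: 2 rows, CSR indptr = {0, 2, 3}, indices = {0, 1, 1}, feature dim = 2.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> indptr = {0, 2, 3}, indices = {0, 1, 1};
  // Node features, row-major: 2 nodes x 2 dims for both src (X) and dst (Y).
  std::vector<float> X = {1, 2, 3, 4}, Y = {5, 6, 7, 8};
  const int64_t dim = 2;
  std::vector<float> out(indices.size());   // one scalar per edge (dot output)
  for (int64_t rid = 0; rid + 1 < static_cast<int64_t>(indptr.size()); ++rid) {
    for (int64_t j = indptr[rid]; j < indptr[rid + 1]; ++j) {
      const int64_t cid = indices[j];
      float acc = 0;
      for (int64_t k = 0; k < dim; ++k)     // reduce over the last dimension
        acc += X[rid * dim + k] * Y[cid * dim + k];
      out[j] = acc;                         // edge id == j (no data permutation)
    }
  }
  for (float v : out) std::cout << v << " ";  // 17 23 53
  std::cout << std::endl;
  return 0;
}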
/*!
* Copyright (c) 2020 by Contributors
* \file kernel/cpu/spmm.cc
* \brief SPMM C APIs and definitions.
*/
#include "./spmm.h"
#include <dgl/array.h>
namespace dgl {
namespace aten {
/*! \brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
if (reduce == "sum") {
SWITCH_OP(op, Op, {
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
});
} else if (reduce == "max" || reduce == "min") {
SWITCH_OP(op, Op, {
if (reduce == "max")
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
else
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
}
}
template void SpMMCsr<kDLCPU, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
/*! \brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
if (reduce == "sum") {
SWITCH_OP(op, Op, {
cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
});
} else if (reduce == "max" || reduce == "min") {
SWITCH_OP(op, Op, {
if (reduce == "max")
cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
else
cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
}
}
template void SpMMCoo<kDLCPU, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/spmm.h
* \brief SPMM CPU kernel function header.
*/
#ifndef DGL_ARRAY_CPU_SPMM_H_
#define DGL_ARRAY_CPU_SPMM_H_
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <limits>
#include <algorithm>
namespace dgl {
namespace aten {
namespace cpu {
/*!
* \brief CPU kernel of SpMM on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \note It uses a node-parallel strategy: different threads are responsible
* for the computation of different nodes.
*/
template <typename IdType, typename DType, typename Op>
void SpMMSumCsr(
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat, NDArray efeat,
NDArray out) {
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>();
const IdType* edges = csr.data.Ptr<IdType>();
const DType* X = ufeat.Ptr<DType>();
const DType* W = efeat.Ptr<DType>();
int64_t dim = bcast.out_len,
lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len;
DType* O = out.Ptr<DType>();
#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (int64_t k = 0; k < dim; ++k) {
DType accum = 0;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx? edges[j] : j;
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr;
accum += Op::Call(lhs_off, rhs_off);
}
out_off[k] = accum;
}
}
}
/*!
* \brief CPU kernel of SpMM on Coo format.
* \param bcast Broadcast information.
* \param coo The Coo matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \note It uses an edge-parallel strategy: different threads are responsible
* for the computation of different edges. To avoid possible data races,
* atomic operations are used in the reduction phase.
*/
template <typename IdType, typename DType, typename Op>
void SpMMSumCoo(
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat, NDArray efeat,
NDArray out) {
const bool has_idx = !IsNullArray(coo.data);
const IdType* row = coo.row.Ptr<IdType>();
const IdType* col = coo.col.Ptr<IdType>();
const IdType* edges = coo.data.Ptr<IdType>();
const DType* X = ufeat.Ptr<DType>();
const DType* W = efeat.Ptr<DType>();
int64_t dim = bcast.out_len,
lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len;
DType* O = out.Ptr<DType>();
const int64_t nnz = coo.row->shape[0];
// fill zero elements
memset(O, 0, out.GetSize());
// spmm
#pragma omp parallel for
for (IdType i = 0; i < nnz; ++i) {
const IdType rid = row[i];
const IdType cid = col[i];
const IdType eid = has_idx? edges[i] : i;
DType* out_off = O + cid * dim;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + rid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr;
const DType val = Op::Call(lhs_off, rhs_off);
#pragma omp atomic
out_off[k] += val;
}
}
}
/*!
* \brief CPU kernel of SpMM-Min/Max on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \param argu Arg-Min/Max on source nodes, i.e. the source node indices that
* correspond to the minimum/maximum values of the reduction result on
* destination nodes. It is useful for computing gradients of the Min/Max reducer.
* \param arge Arg-Min/Max on edges, i.e. the edge indices that correspond to
* the minimum/maximum values of the reduction result on destination nodes.
* It is useful for computing gradients of the Min/Max reducer.
* \note It uses a node-parallel strategy: different threads are responsible
* for the computation of different nodes.
*/
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsr(
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) {
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const IdType* indices = static_cast<IdType*>(csr.indices->data);
const IdType* edges = has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
const DType* X = Op::use_lhs? static_cast<DType*>(ufeat->data) : nullptr;
const DType* W = Op::use_rhs? static_cast<DType*>(efeat->data) : nullptr;
const int64_t dim = bcast.out_len,
lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len;
DType* O = static_cast<DType*>(out->data);
IdType* argX = Op::use_lhs? static_cast<IdType*>(argu->data) : nullptr;
IdType* argW = Op::use_rhs? static_cast<IdType*>(arge->data) : nullptr;
#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
IdType* argx_off = argX + rid * dim;
IdType* argw_off = argW + rid * dim;
for (int64_t k = 0; k < dim; ++k) {
DType accum = Cmp::zero;
IdType ax = 0, aw = 0;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx? edges[j] : j;
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr;
const DType val = Op::Call(lhs_off, rhs_off);
if (Cmp::Call(accum, val)) {
accum = val;
if (Op::use_lhs)
ax = cid;
if (Op::use_rhs)
aw = eid;
}
}
out_off[k] = accum;
if (Op::use_lhs)
argx_off[k] = ax;
if (Op::use_rhs)
argw_off[k] = aw;
}
}
}
/*!
* \brief CPU kernel of SpMM-Min/Max on Coo format.
* \param bcast Broadcast information.
* \param coo The Coo matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \param argu Arg-Min/Max on source nodes, i.e. the source node indices that
* correspond to the minimum/maximum values of the reduction result on
* destination nodes. It is useful for computing gradients of the Min/Max reducer.
* \param arge Arg-Min/Max on edges, i.e. the edge indices that correspond to
* the minimum/maximum values of the reduction result on destination nodes.
* It is useful for computing gradients of the Min/Max reducer.
* \note It uses an edge-parallel strategy: different threads are responsible
* for the computation of different edges. To avoid possible data races,
* the reduction phase is serialized with a critical section.
*/
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCoo(
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) {
const bool has_idx = !IsNullArray(coo.data);
const IdType* row = static_cast<IdType*>(coo.row->data);
const IdType* col = static_cast<IdType*>(coo.col->data);
const IdType* edges = has_idx? static_cast<IdType*>(coo.data->data) : nullptr;
const DType* X = Op::use_lhs? static_cast<DType*>(ufeat->data) : nullptr;
const DType* W = Op::use_rhs? static_cast<DType*>(efeat->data) : nullptr;
const int64_t dim = bcast.out_len,
lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len;
DType* O = static_cast<DType*>(out->data);
IdType* argX = Op::use_lhs? static_cast<IdType*>(argu->data) : nullptr;
IdType* argW = Op::use_rhs? static_cast<IdType*>(arge->data) : nullptr;
const int64_t nnz = coo.row->shape[0];
// fill zero elements
std::fill(O, O + out.NumElements(), Cmp::zero);
// spmm
#pragma omp parallel for
for (IdType i = 0; i < nnz; ++i) {
const IdType rid = row[i];
const IdType cid = col[i];
const IdType eid = has_idx? edges[i] : i;
DType* out_off = O + cid * dim;
IdType* argx_off = Op::use_lhs? argX + cid * dim : nullptr;
IdType* argw_off = Op::use_rhs? argW + cid * dim : nullptr;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + rid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr;
const DType val = Op::Call(lhs_off, rhs_off);
#pragma omp critical
if (Cmp::Call(out_off[k], val)) {
out_off[k] = val;
if (Op::use_lhs)
argx_off[k] = rid;
if (Op::use_rhs)
argw_off[k] = eid;
}
}
}
}
namespace op {
//////////////////////////////// binary operators on CPU ////////////////////////////////
template <typename DType>
struct Add {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* lhs_off, const DType* rhs_off) {
return *lhs_off + *rhs_off;
}
};
template <typename DType>
struct Sub {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* lhs_off, const DType* rhs_off) {
return *lhs_off - *rhs_off;
}
};
template <typename DType>
struct Mul {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* lhs_off, const DType* rhs_off) {
return *lhs_off * *rhs_off;
}
};
template <typename DType>
struct Div {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* lhs_off, const DType* rhs_off) {
return *lhs_off / *rhs_off;
}
};
template <typename DType>
struct CopyLhs {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = false;
inline static DType Call(const DType* lhs_off, const DType* ) {
return *lhs_off;
}
};
template <typename DType>
struct CopyRhs {
static constexpr bool use_lhs = false;
static constexpr bool use_rhs = true;
inline static DType Call(const DType* , const DType* rhs_off) {
return *rhs_off;
}
};
//////////////////////////////// Reduce operators on CPU ////////////////////////////////
template <typename DType>
struct Max {
static constexpr DType zero = std::numeric_limits<DType>::lowest();
// return true if accum should be replaced
inline static bool Call(DType accum, DType val) {
return accum < val;
}
};
template <typename DType>
struct Min {
static constexpr DType zero = std::numeric_limits<DType>::max();
// return true if accum should be replaced
inline static bool Call(DType accum, DType val) {
return accum > val;
}
};
#define SWITCH_OP(op, Op, ...) \
do { \
if ((op) == "add") { \
typedef dgl::aten::cpu::op::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef dgl::aten::cpu::op::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef dgl::aten::cpu::op::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "div") { \
typedef dgl::aten::cpu::op::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_u") { \
typedef dgl::aten::cpu::op::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_e") { \
typedef dgl::aten::cpu::op::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Unsupported SpMM binary operator: " << op; \
} \
} while (0)
} // namespace op
} // namespace cpu
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_CPU_SPMM_H_
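To make the SpMM sum path above concrete, here is a standalone sketch (plain arrays and hypothetical toy data, no DGL types) of the row-parallel accumulation in SpMMSumCsr for op = mul, i.e. u_mul_e followed by a sum reduce. Note that in this CSR layout rows are destination nodes and columns are source nodes.
// Standalone sketch of SpMMSumCsr for op = mul, reduce = sum.
// Toy graph: 2 destination rows, CSR indptr = {0, 2, 3}, indices = {0, 1, 1}.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> indptr = {0, 2, 3}, indices = {0, 1, 1};
  const int64_t dim = 2;
  std::vector<float> X = {1, 2, 3, 4};         // source node features (2 x 2)
  std::vector<float> W = {1, 1, 2, 2, 3, 3};   // edge features (3 edges x 2)
  std::vector<float> O(2 * dim, 0);            // output on destination rows
  for (int64_t rid = 0; rid + 1 < static_cast<int64_t>(indptr.size()); ++rid) {
    for (int64_t k = 0; k < dim; ++k) {
      float accum = 0;
      for (int64_t j = indptr[rid]; j < indptr[rid + 1]; ++j) {
        const int64_t cid = indices[j];              // neighbor (source) node
        accum += X[cid * dim + k] * W[j * dim + k];  // u_mul_e, then sum
      }
      O[rid * dim + k] = accum;
    }
  }
  for (float v : O) std::cout << v << " ";  // 7 10 9 12
  std::cout << std::endl;
  return 0;
}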
/*!
* Copyright (c) 2019 by Contributors
* \file array/cuda/atomic.cuh
* \brief Atomic functions
*/
#ifndef DGL_ARRAY_CUDA_ATOMIC_H_
#define DGL_ARRAY_CUDA_ATOMIC_H_
#include <cuda_runtime.h>
#if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h>
#endif
namespace dgl {
namespace aten {
namespace cuda {
// Type trait for selecting code type
template <int Bytes> struct Code { };
template <> struct Code<4> {
typedef unsigned int Type;
};
template <> struct Code<8> {
typedef unsigned long long int Type;
};
// Helper class for converting to/from atomicCAS compatible types.
template <typename T> struct Cast {
typedef typename Code<sizeof(T)>::Type Type;
static __device__ __forceinline__ Type Encode(T val) {
return static_cast<Type>(val);
}
static __device__ __forceinline__ T Decode(Type code) {
return static_cast<T>(code);
}
};
template <> struct Cast<float> {
typedef Code<sizeof(float)>::Type Type;
static __device__ __forceinline__ Type Encode(float val) {
return __float_as_uint(val);
}
static __device__ __forceinline__ float Decode(Type code) {
return __uint_as_float(code);
}
};
template <> struct Cast<double> {
typedef Code<sizeof(double)>::Type Type;
static __device__ __forceinline__ Type Encode(double val) {
return __double_as_longlong(val);
}
static __device__ __forceinline__ double Decode(Type code) {
return __longlong_as_double(code);
}
};
#define DEFINE_ATOMIC(NAME) \
template <typename T> \
__device__ __forceinline__ T Atomic##NAME(T* addr, T val) { \
typedef typename Cast<T>::Type CT; \
CT* addr_as_ui = reinterpret_cast<CT*>(addr); \
CT old = *addr_as_ui; \
CT assumed = old; \
do { \
assumed = old; \
old = atomicCAS(addr_as_ui, assumed, \
Cast<T>::Encode(OP(val, Cast<T>::Decode(old)))); \
} while (assumed != old); \
return Cast<T>::Decode(old); \
}
#define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max)
#undef OP
#define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min)
#undef OP
#define OP(a, b) a + b
DEFINE_ATOMIC(Add)
#undef OP
#if __CUDA_ARCH__ >= 200
template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
return atomicAdd(addr, val);
}
#endif // __CUDA_ARCH__
#if __CUDA_ARCH__ >= 600
template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
return atomicAdd(addr, val);
}
#endif
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
#if __CUDA_ARCH__ >= 600
template <>
__device__ __forceinline__ __half2 AtomicAdd<__half2>(__half2* addr, __half2 val) {
return atomicAdd(addr, val);
}
#endif // __CUDA_ARCH__
#if __CUDA_ARCH__ >= 700
template <>
__device__ __forceinline__ __half AtomicAdd<__half>(__half* addr, __half val) {
return atomicAdd(addr, val);
}
#endif // __CUDA_ARCH__
#endif
#define OP(a, b) a * b
DEFINE_ATOMIC(Mul)
#undef OP
} // namespace cuda
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_CUDA_ATOMIC_H_
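A small CUDA sketch of how the CAS-based atomics above can be exercised (a hypothetical test file; it assumes compilation inside the DGL source tree so that array/cuda/atomic.cuh is on the include path):
// Hypothetical test sketch for the CAS-based AtomicMax above.
// Assumed build: nvcc -I<dgl>/src test_atomic.cu
#include <cstdio>
#include <cuda_runtime.h>
#include "array/cuda/atomic.cuh"

__global__ void MaxReduceKernel(const float* data, int n, float* result) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Every thread races to update *result; the CAS loop keeps the maximum.
    dgl::aten::cuda::AtomicMax(result, data[i]);
  }
}

int main() {
  const int n = 1024;
  float host[n], init = -1e30f;
  for (int i = 0; i < n; ++i) host[i] = static_cast<float>(i % 357);
  float *d_data, *d_result;
  cudaMalloc(&d_data, n * sizeof(float));
  cudaMalloc(&d_result, sizeof(float));
  cudaMemcpy(d_data, host, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_result, &init, sizeof(float), cudaMemcpyHostToDevice);
  MaxReduceKernel<<<(n + 255) / 256, 256>>>(d_data, n, d_result);
  float out = 0.f;
  cudaMemcpy(&out, d_result, sizeof(float), cudaMemcpyDeviceToHost);
  printf("max = %f\n", out);  // expected: 356
  cudaFree(d_data);
  cudaFree(d_result);
  return 0;
}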
...@@ -82,6 +82,9 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
workspace));
device->FreeWorkspace(row->ctx, workspace);
// The row and column fields have already been reordered according
// to data, so the data field is no longer needed and is cleared.
coo.data = aten::NullArray();
return coo;
}
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/functor.cuh
* \brief Functors for template on CUDA
*/
#ifndef DGL_ARRAY_CUDA_FUNCTOR_CUH_
#define DGL_ARRAY_CUDA_FUNCTOR_CUH_
namespace dgl {
namespace aten {
namespace cuda {
/////////////////////////////// CUDA binary operators ///////////////////////////////
namespace binary {
template <typename DType>
struct Add {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
return lhs[0] + rhs[0];
}
};
template <typename DType>
struct Sub {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
return lhs[0] - rhs[0];
}
};
template <typename DType>
struct Mul {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
return lhs[0] * rhs[0];
}
};
template <typename DType>
struct Div {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
return lhs[0] / rhs[0];
}
};
template <typename DType>
struct CopyU {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = false;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
return lhs[0];
}
};
template <typename DType>
struct CopyE {
static constexpr bool use_lhs = false;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
return rhs[0];
}
};
template <typename DType>
struct Dot {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = true;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
DType rst = static_cast<DType>(0);
for (int64_t i = 0; i < len; ++i) {
rst += lhs[i] * rhs[i];
}
return rst;
}
};
} // end of namespace binary
/////////////////////////////// CUDA reduce operators ///////////////////////////////
namespace reduce {
template <typename Idx,
typename DType,
bool atomic=false>
struct Sum {
static constexpr DType zero = 0;
static constexpr bool require_arg = false;
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
DType val, Idx uid, Idx eid) {
if (!atomic) {
*out_buf += val;
} else {
cuda::AtomicAdd(out_buf, val);
}
}
static __device__ __forceinline__ void CallArg(Idx fid,
Idx *arg_u_buf, Idx *arg_e_buf,
DType val, DType val_ref, Idx uid, Idx eid) {}
};
template <typename Idx,
typename DType,
bool atomic=false>
struct Max {
static constexpr DType zero = std::numeric_limits<DType>::lowest();
static constexpr bool require_arg = true;
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
DType val, Idx uid, Idx eid) {
if (!atomic) {
if (*out_buf < val) {
*out_buf = val;
*arg_u_buf = uid;
*arg_e_buf = eid;
}
} else {
cuda::AtomicMax(out_buf, val);
}
}
static __device__ __forceinline__ void CallArg(Idx fid,
Idx *arg_u_buf, Idx *arg_e_buf,
DType val, DType val_ref, Idx uid, Idx eid) {
if (atomic) {
if (val == val_ref) {
if (arg_u_buf)
arg_u_buf[fid] = uid;
if (arg_e_buf)
arg_e_buf[fid] = eid;
}
}
}
};
template <typename Idx,
typename DType,
bool atomic=false>
struct Min {
static constexpr DType zero = std::numeric_limits<DType>::max();
static constexpr bool require_arg = true;
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
DType val, Idx uid, Idx eid) {
if (!atomic) {
if (*out_buf > val) {
*out_buf = val;
*arg_u_buf = uid;
*arg_e_buf = eid;
}
} else {
cuda::AtomicMin(out_buf, val);
}
}
static __device__ __forceinline__ void CallArg(Idx fid,
Idx *arg_u_buf, Idx *arg_e_buf,
DType val, DType val_ref, Idx uid, Idx eid) {
if (atomic) {
if (val == val_ref) {
if (arg_u_buf)
arg_u_buf[fid] = uid;
if (arg_e_buf)
arg_e_buf[fid] = eid;
}
}
}
};
} // namespace reduce
} // namespace cuda
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_CUDA_FUNCTOR_CUH_
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/macro.cuh
* \brief Macro to call SPMM/SDDMM cuda kernels.
*/
#ifndef DGL_ARRAY_CUDA_MACRO_CUH_
#define DGL_ARRAY_CUDA_MACRO_CUH_
///////////////////////// Dispatchers //////////////////////////
/* Macro used for switching between broadcasting and non-broadcasting kernels.
* It also copies the auxiliary information for calculating broadcasting offsets
* to GPU.
*/
#define BCAST_IDX_CTX_SWITCH(BCAST, EDGE_MAP, CTX, LHS_OFF, RHS_OFF, ...) do { \
const BcastOff &info = (BCAST); \
if (!info.use_bcast) { \
constexpr bool UseBcast = false; \
if ((EDGE_MAP)) { \
constexpr bool UseIdx = true; \
{ __VA_ARGS__ } \
} else { \
constexpr bool UseIdx = false; \
{ __VA_ARGS__ } \
} \
} else { \
constexpr bool UseBcast = true; \
const DLContext ctx = (CTX); \
const auto device = runtime::DeviceAPI::Get(ctx); \
(LHS_OFF) = static_cast<int64_t*>( \
device->AllocWorkspace(ctx, sizeof(int64_t) * info.lhs_offset.size())); \
CUDA_CALL(cudaMemcpy((LHS_OFF), &info.lhs_offset[0], \
sizeof(int64_t) * info.lhs_offset.size(), cudaMemcpyHostToDevice)); \
(RHS_OFF) = static_cast<int64_t*>( \
device->AllocWorkspace(ctx, sizeof(int64_t) * info.rhs_offset.size())); \
CUDA_CALL(cudaMemcpy((RHS_OFF), &info.rhs_offset[0], \
sizeof(int64_t) * info.rhs_offset.size(), cudaMemcpyHostToDevice)); \
if ((EDGE_MAP)) { \
constexpr bool UseIdx = true; \
{ __VA_ARGS__ } \
} else { \
constexpr bool UseIdx = false; \
{ __VA_ARGS__ } \
} \
device->FreeWorkspace(ctx, (LHS_OFF)); \
device->FreeWorkspace(ctx, (RHS_OFF)); \
} \
} while (0)
#endif  // DGL_ARRAY_CUDA_MACRO_CUH_
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/sddmm.cu
* \brief SDDMM C APIs and definitions.
*/
#include <dgl/array.h>
#include "./sddmm.cuh"
#include "./functor.cuh"
namespace dgl {
namespace aten {
#define SWITCH_OP(op, Op, ...) \
do { \
if ((op) == "add") { \
typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "div") { \
typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_u") { \
typedef cuda::binary::CopyU<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_e") { \
typedef cuda::binary::CopyE<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "dot") { \
typedef cuda::binary::Dot<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << op; \
} \
} while (0)
/*!
* \brief CUDA implementation of g-SDDMM on Csr format.
*/
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat,
NDArray vfeat,
NDArray out) {
SWITCH_OP(op, Op, {
cuda::SDDMMCsr<IdType, DType, Op>(bcast, csr, ufeat, vfeat, out);
});
}
/*!
* \brief CUDA implementation of g-SDDMM on Coo format.
*/
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat,
NDArray vfeat,
NDArray out) {
SWITCH_OP(op, Op, {
cuda::SDDMMCoo<IdType, DType, Op>(bcast, coo, ufeat, vfeat, out);
});
}
template void SDDMMCsr<kDLGPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray vfeat, NDArray out);
template void SDDMMCsr<kDLGPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray vfeat, NDArray out);
template void SDDMMCsr<kDLGPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray vfeat, NDArray out);
template void SDDMMCsr<kDLGPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray vfeat, NDArray out);
template void SDDMMCoo<kDLGPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray vfeat, NDArray out);
template void SDDMMCoo<kDLGPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray vfeat, NDArray out);
template void SDDMMCoo<kDLGPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray vfeat, NDArray out);
template void SDDMMCoo<kDLGPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray vfeat, NDArray out);
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/sddmm.cuh
* \brief SDDMM CUDA kernel function header.
*/
#ifndef DGL_ARRAY_CUDA_SDDMM_CUH_
#define DGL_ARRAY_CUDA_SDDMM_CUH_
#include <dgl/bcast.h>
#include "macro.cuh"
#include "atomic.cuh"
#include "functor.cuh"
#include "../../cuda_utils.h"
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
using namespace cuda;
namespace aten {
namespace cuda {
/*!
* \brief CUDA kernel of g-SDDMM on Coo format.
* \note It uses an edge-parallel strategy: different threadblocks (on the y-axis)
* are responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions
* in the feature dimension.
*/
template <typename Idx, typename DType, typename BinaryOp,
bool UseBcast = false, bool UseIdx = false>
__global__ void SDDMMCooKernel(
const DType *ufeat, const DType *vfeat, DType *out,
const Idx *row, const Idx *col, const Idx* edge_map,
int64_t N, int64_t M, int64_t E, int64_t reduce_size,
const int64_t *ubcast_off, const int64_t *vbcast_off,
int64_t ufeat_len, int64_t vfeat_len, int64_t out_len) {
// SDDMM with COO.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y;
while (ty < E) {
const Idx src = _ldg(row + ty);
const Idx dst = _ldg(col + ty);
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
const DType* lhsoff = BinaryOp::use_lhs ?
(ufeat + src * ufeat_len): nullptr;
const DType* rhsoff = BinaryOp::use_rhs ?
(vfeat + dst * vfeat_len): nullptr;
DType* outoff = out + eid * out_len;
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = blockDim.x * gridDim.x;
while (tx < out_len) {
const Idx lhs_add = UseBcast ? ubcast_off[tx] : tx;
const Idx rhs_add = UseBcast ? vbcast_off[tx] : tx;
DType val = BinaryOp::Call(
lhsoff + lhs_add * reduce_size,
rhsoff + rhs_add * reduce_size,
reduce_size);
outoff[tx] = val;
tx += stride_x;
}
ty += stride_y;
}
}
// Binary search the row_offsets to find the source node of the edge id.
template <typename Idx>
__device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx eid) {
Idx lo = 0, hi = length - 1;
while (lo < hi) {
Idx mid = (lo + hi) >> 1;
if (_ldg(array + mid) <= eid) {
lo = mid + 1;
} else {
hi = mid;
}
}
// INVARIANT: lo == hi
if (_ldg(array + hi) == eid) {
return hi;
} else {
return hi - 1;
}
}
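// Worked example (illustrative): for indptr = {0, 2, 2, 5} (3 rows, 5 edges),
// edge ids 0-1 belong to row 0, row 1 is empty, and edge ids 2-4 belong to row 2.
// For eid = 3 the search terminates with lo == hi == 3; since indptr[3] == 5 != 3
// it returns hi - 1 = 2, i.e. the source row whose range [indptr[2], indptr[3])
// contains edge 3.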
/*!
* \brief CUDA kernel of g-SDDMM on Csr format.
* \note It uses an edge-parallel strategy: different threadblocks (on the y-axis)
* are responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions
* in the feature dimension.
* To efficiently find the source and destination node indices of a
* given edge in Csr format, it uses binary search (time complexity O(log N)).
*/
template <typename Idx, typename DType, typename BinaryOp,
bool UseBcast = false, bool UseIdx = false>
__global__ void SDDMMCsrKernel(
const DType *ufeat, const DType *vfeat, DType *out,
const Idx *indptr, const Idx *indices, const Idx* edge_map,
int64_t N, int64_t M, int64_t E, int64_t reduce_size,
int64_t *ubcast_off, int64_t *vbcast_off,
int64_t ufeat_len, int64_t vfeat_len, int64_t out_len) {
// SDDMM with Csr.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y;
while (ty < E) {
const Idx src = BinarySearchSrc<Idx>(indptr, N + 1, ty);
const Idx dst = _ldg(indices + ty);
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t stride_x = blockDim.x * gridDim.x;
const DType* lhsoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len): nullptr;
const DType* rhsoff = BinaryOp::use_rhs ? (vfeat + dst * vfeat_len): nullptr;
DType* outoff = out + eid * out_len;
while (tx < out_len) {
const Idx lhs_add = UseBcast ? ubcast_off[tx] : tx;
const Idx rhs_add = UseBcast ? vbcast_off[tx] : tx;
DType val = BinaryOp::Call(
lhsoff + lhs_add * reduce_size,
rhsoff + rhs_add * reduce_size,
reduce_size);
outoff[tx] = val;
tx += stride_x;
}
ty += stride_y;
}
}
/*!
* \brief CUDA implementation of g-SDDMM on Coo format.
* \param bcast Broadcast information.
* \param coo The Coo matrix.
* \param ufeat The feature on source nodes.
* \param vfeat The feature on destination nodes.
* \param out The result feature on edges.
*/
template <typename Idx, typename DType, typename Op>
void SDDMMCoo(
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat,
NDArray vfeat,
NDArray out) {
const Idx *row = coo.row.Ptr<Idx>();
const Idx *col = coo.col.Ptr<Idx>();
const Idx *edge_map = coo.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>();
const DType *vfeat_data = vfeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
int64_t *ubcast_off = nullptr, *vbcast_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t reduce_dim = bcast.reduce_size;
const int64_t nnz = coo.row->shape[0];
const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nbx = (len + ntx - 1) / ntx;
const int nby = FindNumBlocks<'y'>((nnz + nty - 1) / nty);
//LOG(INFO) << "nblks=(" << nbx << ", " << nby << ") nthrs=(" << ntx << ", " << nty << ")";
const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty);
const bool use_idx = !IsNullArray(coo.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, vbcast_off, {
SDDMMCooKernel<Idx, DType, Op, UseBcast, UseIdx>
<<<nblks, nthrs, 0, thr_entry->stream>>>(
ufeat_data, vfeat_data, out_data,
row, col, edge_map,
coo.num_rows, coo.num_cols, nnz, reduce_dim,
ubcast_off, vbcast_off,
lhs_len, rhs_len, len
);
});
}
/*!
* \brief CUDA implementation of g-SDDMM on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param vfeat The feature on destination nodes.
* \param out The result feature on edges.
*/
template <typename Idx, typename DType, typename Op>
void SDDMMCsr(
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat,
NDArray vfeat,
NDArray out) {
const Idx *indptr = csr.indptr.Ptr<Idx>();
const Idx *indices = csr.indices.Ptr<Idx>();
const Idx *edge_map = csr.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>();
const DType *vfeat_data = vfeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];
int64_t *ubcast_off = nullptr, *vbcast_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t reduce_dim = bcast.reduce_size;
const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nbx = (len + ntx - 1) / ntx;
const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);
const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty);
const bool use_idx = !IsNullArray(csr.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, vbcast_off, {
SDDMMCsrKernel<Idx, DType, Op, UseBcast, UseIdx>
<<<nblks, nthrs, 0, thr_entry->stream>>>(
ufeat_data, vfeat_data, out_data,
indptr, indices, edge_map,
N, M, E, reduce_dim,
ubcast_off, vbcast_off,
lhs_len, rhs_len, len
);
});
}
} // namespace cuda
} // namespace aten
} // namespace dgl
#endif
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/spmm.cu
* \brief SPMM C APIs and definitions.
*/
#include <dgl/array.h>
#include "./spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
using namespace cuda;
namespace aten {
namespace {
/*! \brief Fill the vector starting at ptr, of size length, with val. */
template <typename DType>
void _Fill(DType* ptr, size_t length, DType val) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
int nt = FindNumThreads(length);
int nb = (length + nt - 1) / nt;  // on the x-axis, no need to worry about the grid-size upper bound.
cuda::_FillKernel<<<nb, nt, 0, thr_entry->stream>>>(ptr, length, val);
}
} // namespace
namespace cusparse {
template <typename DType>
cusparseStatus_t Xcsrmm2(cusparseHandle_t handle, cusparseOperation_t transA,
cusparseOperation_t transB, int m, int n, int k, int nnz,
const DType* alpha, const cusparseMatDescr_t descrA,
const DType* csrValA, const int* csrRowPtrA, const int* csrColIndA,
const DType* B, int ldb, const DType* beta, DType* C, int ldc) {
LOG(INFO) << "Not supported dtype";
return CUSPARSE_STATUS_EXECUTION_FAILED;
}
template <>
cusparseStatus_t Xcsrmm2<float>(cusparseHandle_t handle, cusparseOperation_t transA,
cusparseOperation_t transB, int m, int n, int k, int nnz,
const float* alpha, const cusparseMatDescr_t descrA,
const float* csrValA, const int* csrRowPtrA, const int* csrColIndA,
const float* B, int ldb, const float* beta, float* C, int ldc) {
return cusparseScsrmm2(handle, transA, transB, m, n, k, nnz,
alpha, descrA, csrValA, csrRowPtrA, csrColIndA,
B, ldb, beta, C, ldc);
}
template <>
cusparseStatus_t Xcsrmm2<double>(cusparseHandle_t handle, cusparseOperation_t transA,
cusparseOperation_t transB, int m, int n, int k, int nnz,
const double* alpha, const cusparseMatDescr_t descrA,
const double* csrValA, const int* csrRowPtrA, const int* csrColIndA,
const double* B, int ldb, const double* beta, double* C, int ldc) {
return cusparseDcsrmm2(handle, transA, transB, m, n, k, nnz,
alpha, descrA, csrValA, csrRowPtrA, csrColIndA,
B, ldb, beta, C, ldc);
}
template <typename DType>
cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const DType* alpha, const DType* A, int lda,
const DType* beta, const DType* B, int ldb,
DType* C, int ldc) {
LOG(INFO) << "Not supported dtype";
return CUBLAS_STATUS_EXECUTION_FAILED;
}
template <>
cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const float* alpha, const float* A, int lda,
const float* beta, const float* B, int ldb,
float* C, int ldc) {
return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda,
beta, B, ldb, C, ldc);
}
template <>
cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const double* alpha, const double* A, int lda,
const double* beta, const double* B, int ldb,
double* C, int ldc) {
return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda,
beta, B, ldb, C, ldc);
}
/*! Cusparse implementation of SpMM on Csr format. */
template <typename DType>
void CusparseCsrmm2(
const DLContext& ctx,
const CSRMatrix& csr,
const DType* B_data, const DType* A_data,
DType* C_data,
int x_length) {
// We use csrmm2 to perform the following operation:
// C = A x B, where A is the sparse adjacency matrix in CSR format and B is the dense
// node feature matrix. However, since cuSPARSE only supports column-major layout while
// our tensors are stored in row-major layout, the actual computation is:
// C = trans(A x trans(B)).
// Currently, we use cublasXgeam to perform the final transposition and allocate an
// intermediate workspace buffer for it.
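// A sketch of the shape bookkeeping (assuming row-major DGL tensors):
//   A: m x k sparse (CSR); B: k x n row-major, which cuSPARSE views as an
//   n x k column-major matrix, hence transB = CUSPARSE_OPERATION_TRANSPOSE.
//   csrmm2 writes A x B column-major into trans_out (leading dimension m),
//   which is exactly (A x B)^T when read as a row-major buffer; the
//   cublasXgeam call below transposes it back into the row-major output.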
const int m = csr.num_rows;
const int n = x_length;
const int k = csr.num_cols;
const int nnz = csr.indices->shape[0];
const DType alpha = 1.0;
const DType beta = 0.0;
// device
auto device = runtime::DeviceAPI::Get(ctx);
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
// allocate cusparse handle if needed
if (!thr_entry->cusparse_handle) {
CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
}
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
// allocate matrix for temporary transposed output
DType* trans_out = static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
// All-one value array, used when the graph carries no edge weights so csrmm2 computes an unweighted sum.
DType* valptr = nullptr;
if (!A_data) {
valptr = static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
_Fill(valptr, nnz, static_cast<DType>(1.));
}
cusparseMatDescr_t descr;
CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
CUSPARSE_CALL(Xcsrmm2<DType>(
thr_entry->cusparse_handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE,
m, n, k, nnz, &alpha,
descr, (valptr)? valptr : A_data,
static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, trans_out, m));
if (valptr)
device->FreeWorkspace(ctx, valptr);
// transpose the output matrix
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, thr_entry->stream));
CUBLAS_CALL(Xgeam<DType>(
thr_entry->cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
n, m,
&alpha, trans_out, m,
&beta, nullptr, n,
C_data, n));
device->FreeWorkspace(ctx, trans_out);
}
} // namespace cusparse
#define SWITCH_OP(op, Op, ...) \
do { \
if ((op) == "add") { \
typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "div") { \
typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_u") { \
typedef cuda::binary::CopyU<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_e") { \
typedef cuda::binary::CopyE<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Unsupported SpMM binary operator: " << op; \
} \
} while (0)
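// Illustrative usage of SWITCH_OP: it binds the functor type matching the op
// string to `Op` and executes the braced body, e.g.
//   SWITCH_OP(op, Op, {
//     cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
//         bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
//   });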
/*!
* \brief CUDA implementation of g-SpMM on Csr format.
* \note Uses cuSPARSE when the reduce operator is `sum` and there is
* no broadcast; falls back to DGL's own kernel in all other cases.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
if (reduce == "sum") {
if (sizeof(IdType) == 4 && op == "copy_u") {
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i)
x_length *= ufeat->shape[i];
cusparse::CusparseCsrmm2<DType>(
ufeat->ctx, csr,
static_cast<DType*>(ufeat->data),
nullptr,
static_cast<DType*>(out->data),
x_length);
} else if (sizeof(IdType) == 4 && op == "mul" && efeat.NumElements() == csr.indices->shape[0]) {
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i)
x_length *= ufeat->shape[i];
if (!IsNullArray(csr.data))
efeat = IndexSelect(efeat, csr.data);
cusparse::CusparseCsrmm2<DType>(
ufeat->ctx, csr,
static_cast<DType*>(ufeat->data),
static_cast<DType*>(efeat->data),
static_cast<DType*>(out->data),
x_length);
} else {
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
});
}
} else if (reduce == "max") {
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else if (reduce == "min") {
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else {
LOG(FATAL) << "Not implemented";
}
}
/*!
* \brief CUDA implementation of g-SpMM on Coo format.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
if (reduce == "sum") {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> > (
bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
});
} else if (reduce == "max") {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> > (
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else if (reduce == "min") {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> > (
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else {
LOG(FATAL) << "Not implemented";
}
}
template void SpMMCsr<kDLGPU, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/spmm.cuh
* \brief SPMM CUDA kernel function header.
*/
#ifndef DGL_ARRAY_CUDA_SPMM_CUH_
#define DGL_ARRAY_CUDA_SPMM_CUH_
#include <dgl/bcast.h>
#include "macro.cuh"
#include "atomic.cuh"
#include "../../cuda_utils.h"
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
using namespace cuda;
namespace aten {
namespace cuda {
/*!
* \brief CUDA kernel that fills the vector starting at ptr, of size length,
* with val.
* \note internal use only.
*/
template <typename DType>
__global__ void _FillKernel(DType* ptr, size_t length, DType val) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
ptr[tx] = val;
tx += stride_x;
}
}
/*!
* \brief CUDA kernel of g-SpMM on Coo format.
* \note It uses an edge-parallel strategy: different threadblocks (on the y-axis)
* are responsible for the computation on different edges, while threadblocks
* on the x-axis are responsible for the computation on different positions
* in the feature dimension.
* To avoid possible data hazards, it uses atomic operators for reduction.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCooKernel(
const DType *ufeat, const DType *efeat, DType *out, Idx *arg_u, Idx *arg_e,
const Idx *row, const Idx *col, const Idx* edge_map,
int64_t N, int64_t M, int64_t E,
int64_t *ubcast_off, int64_t *ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
// SPMM with COO.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y;
while (ty < E) {
const Idx src = _ldg(row + ty);
const Idx dst = _ldg(col + ty);
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t stride_x = blockDim.x * gridDim.x;
const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len): nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr;
DType* outoff = out + dst * out_len;
while (tx < out_len) {
const int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;
const int64_t rhs_add = UseBcast ? ebcast_off[tx] : tx;
DType val = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
Idx* arguoff = nullptr; // arguoff is not used in SpMMCoo.
Idx* argeoff = nullptr; // argeoff is not used in SpMMCoo.
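// ReduceOp::Call updates the destination row with an atomic operation so that
// edges sharing the same dst can accumulate concurrently; arg indices are
// filled by a second pass (ArgSpMMCooKernel) when required.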
ReduceOp::Call(outoff + tx, arguoff, argeoff, val, src, eid);
tx += stride_x;
}
ty += stride_y;
}
}
/*!
* \brief CUDA kernel to compute argu and arge in g-SpMM on Coo format.
* \note It uses an edge-parallel strategy: different threadblocks (on the y-axis)
* are responsible for the computation on different edges, while threadblocks
* on the x-axis are responsible for the computation on different positions
* in the feature dimension.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
__global__ void ArgSpMMCooKernel(
const DType *ufeat, const DType *efeat, DType *out, Idx *arg_u, Idx *arg_e,
const Idx *row, const Idx *col, const Idx* edge_map,
int64_t N, int64_t M, int64_t E,
int64_t *ubcast_off, int64_t *ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
// SPMM with COO arg max/min.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y;
while (ty < E) {
const Idx src = _ldg(row + ty);
const Idx dst = _ldg(col + ty);
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t stride_x = blockDim.x * gridDim.x;
const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len): nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr;
const DType* outoff = out + dst * out_len;
Idx* arguoff = BinaryOp::use_lhs ? (arg_u + dst * out_len): nullptr;
Idx* argeoff = BinaryOp::use_rhs ? (arg_e + dst * out_len): nullptr;
while (tx < out_len) {
int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;
int64_t rhs_add = UseBcast ? ebcast_off[tx] : tx;
DType val = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
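// CallArg compares val against the already-reduced output outoff[tx] and
// records the source node id / edge id that attains the extremum.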
ReduceOp::CallArg(tx, arguoff, argeoff, val, outoff[tx], src, eid);
tx += stride_x;
}
ty += stride_y;
}
}
/*!
* \brief CUDA kernel of g-SpMM on Csr format.
* \note It uses a node-parallel strategy: different threadblocks (on the y-axis)
* are responsible for the computation on different destination nodes,
* while threadblocks on the x-axis are responsible for the computation on
* different positions in the feature dimension.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCsrKernel(
const DType *ufeat, const DType *efeat, DType *out, Idx *arg_u, Idx *arg_e,
const Idx *indptr, const Idx *indices, const Idx *edge_map,
int64_t num_rows, int64_t num_cols, int64_t nnz,
int64_t *ubcast_off, int64_t *ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
// SPMM with CSR.
int ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y;
const int stride_x = blockDim.x * gridDim.x;
while (ty < num_rows) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
while (tx < out_len) {
DType local_accum = ReduceOp::zero;
Idx local_argu = 0, local_arge = 0;
const int lhs_add = UseBcast ? ubcast_off[tx] : tx;
const int rhs_add = UseBcast ? ebcast_off[tx] : tx;
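// Sequentially reduce over the incoming nonzeros of row `ty`; each thread owns
// a distinct output position, so no atomics are needed here.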
for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {
const Idx eid = UseIdx ? _ldg(edge_map + i) : i;
const Idx cid = _ldg(indices + i);
const DType* uoff = BinaryOp::use_lhs ? (ufeat + cid * ufeat_len): nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr;
const DType tmp_out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
ReduceOp::Call(&local_accum, &local_argu, &local_arge, tmp_out, cid, eid);
}
out[ty * out_len + tx] = local_accum;
if (ReduceOp::require_arg && BinaryOp::use_lhs)
arg_u[ty * out_len + tx] = local_argu;
if (ReduceOp::require_arg && BinaryOp::use_rhs)
arg_e[ty * out_len + tx] = local_arge;
tx += stride_x;
}
ty += stride_y;
}
}
/*!
* \brief CUDA implementation of g-SpMM on Coo format.
* \param bcast Broadcast information.
* \param coo The Coo matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \param argu Arg-Min/Max on source nodes, which refers to the source node indices
* that correspond to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of the Min/Max reducer.
* \param arge Arg-Min/Max on edges, which refers to the edge indices
* that correspond to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of the Min/Max reducer.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp>
void SpMMCoo(
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) {
const Idx *row = coo.row.Ptr<Idx>(),
*col = coo.col.Ptr<Idx>(),
*edge_map = coo.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>(),
*efeat_data = efeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
Idx *argu_data = argu.Ptr<Idx>(),
*arge_data = arge.Ptr<Idx>();
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const int64_t N = coo.num_rows, M = coo.num_cols, E = coo.row->shape[0];
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t out_size = out.NumElements();
const int nt = FindNumThreads(out_size);
const int nb = (out_size + nt - 1) / nt;
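// Initialize the output with the reduction identity (ReduceOp::zero) since
// SpMMCooKernel accumulates into it atomically.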
_FillKernel<<<nb, nt, 0, thr_entry->stream>>>(out_data, out_size, ReduceOp::zero);
const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nbx = (len + ntx - 1) / ntx;
const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);
//LOG(INFO) << "nblks=(" << nbx << ", " << nby << ") nthrs=(" << ntx << ", " << nty << ")";
const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty);
const bool use_idx = !IsNullArray(coo.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
SpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>
<<<nblks, nthrs, 0, thr_entry->stream>>>(
ufeat_data, efeat_data, out_data, argu_data, arge_data,
row, col, edge_map,
N, M, E,
ubcast_off, ebcast_off,
lhs_len, rhs_len, len
);
if (ReduceOp::require_arg) {
ArgSpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>
<<<nblks, nthrs, 0, thr_entry->stream>>>(
ufeat_data, efeat_data, out_data, argu_data, arge_data,
row, col, edge_map,
N, M, E,
ubcast_off, ebcast_off,
lhs_len, rhs_len, len
);
}
});
}
/*!
* \brief CUDA implementation of g-SpMM on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
* \param argu Arg-Min/Max on source nodes, which refers to the source node indices
* that correspond to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of the Min/Max reducer.
* \param arge Arg-Min/Max on edges, which refers to the edge indices
* that correspond to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of the Min/Max reducer.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp>
void SpMMCsr(
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) {
const Idx *indptr = csr.indptr.Ptr<Idx>();
const Idx *indices = csr.indices.Ptr<Idx>();
const Idx *edge_map = csr.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>();
const DType *efeat_data = efeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
Idx* argu_data = argu.Ptr<Idx>();
Idx* arge_data = arge.Ptr<Idx>();
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nbx = (len + ntx - 1) / ntx;
const int nby = FindNumBlocks<'y'>((csr.num_rows + nty - 1) / nty);
//LOG(INFO) << "nblks=(" << nbx << ", " << nby << ") nthrs=(" << ntx << ", " << nty << ")";
const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty);
const bool use_idx = !IsNullArray(csr.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>
<<<nblks, nthrs, 0, thr_entry->stream>>>(
ufeat_data, efeat_data, out_data, argu_data, arge_data,
indptr, indices, edge_map,
csr.num_rows, csr.num_cols, csr.indices->shape[0],
ubcast_off, ebcast_off,
lhs_len, rhs_len, len
);
});
}
} // namespace cuda
} // namespace aten
} // namespace dgl
#endif
/*!
* Copyright (c) 2020 by Contributors
* \file array/kernel.cc
* \brief New kernels
*/
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/base_heterograph.h>
#include "kernel_decl.h"
#include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl {
namespace aten {
namespace {
// Check whether the given arguments have the same context.
inline void CheckCtx(
const DLContext& ctx,
const std::vector<NDArray>& arrays,
const std::vector<std::string>& names) {
for (size_t i = 0; i < arrays.size(); ++i) {
if (IsNullArray(arrays[i]))
continue;
CHECK_EQ(ctx, arrays[i]->ctx)
<< "Expected device context " << ctx << ". But got "
<< arrays[i]->ctx << " for " << names[i] << ".";
}
}
// Check whether input tensors are contiguous.
inline void CheckContiguous(
const std::vector<NDArray>& arrays,
const std::vector<std::string>& names) {
for (size_t i = 0; i < arrays.size(); ++i) {
if (IsNullArray(arrays[i]))
continue;
CHECK(arrays[i].IsContiguous())
<< "Expect " << names[i] << " to be a contiguous tensor";
}
}
// Check whether input tensors have valid shape.
inline void CheckShape(
const std::vector<uint64_t>& gdim,
const std::vector<int>& uev_idx,
const std::vector<NDArray>& arrays,
const std::vector<std::string>& names) {
for (size_t i = 0; i < arrays.size(); ++i) {
if (IsNullArray(arrays[i]))
continue;
CHECK_GE(arrays[i]->ndim, 2)
<< "Expect " << names[i] << " to have ndim >= 2, "
<< "Note that for scalar feature we expand its "
<< "dimension with an additional dimension of "
<< "length one.";
CHECK_EQ(gdim[uev_idx[i]], arrays[i]->shape[0])
<< "Expect " << names[i] << " to have size "
<< gdim[uev_idx[i]] << " on the first dimension, "
<< "but got " << arrays[i]->shape[0];
}
}
} // namespace
/*! \brief Generalized Sparse Matrix-Matrix Multiplication. */
void SpMM(const std::string& op, const std::string& reduce,
HeteroGraphPtr graph,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux,
SparseFormat format) {
// TODO(zihao): format tuning
format = SparseFormat::kCSR;
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH(out->dtype, DType, "Feature data", {
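// SpMM reduces messages onto destination nodes, so the graph's CSC
// (i.e. the in-edge CSR) is the sparse matrix handed to the CSR kernel.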
if (format == SparseFormat::kCSR) {
SpMMCsr<XPU, IdType, DType>(
op, reduce, bcast, graph->GetCSCMatrix(0),
ufeat, efeat, out, out_aux);
} else if (format == SparseFormat::kCOO) {
SpMMCoo<XPU, IdType, DType>(
op, reduce, bcast, graph->GetCOOMatrix(0),
ufeat, efeat, out, out_aux);
} else {
LOG(FATAL) << "SpMM only supports CSR and COO foramts";
}
});
});
});
}
/*! \brief Generalized Sampled Dense-Dense Matrix Multiplication. */
void SDDMM(const std::string& op,
HeteroGraphPtr graph,
NDArray ufeat,
NDArray efeat,
NDArray out,
SparseFormat format) {
// TODO(zihao): format tuning
format = SparseFormat::kCOO;
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH(out->dtype, DType, "Feature data", {
if (format == SparseFormat::kCSR) {
SDDMMCsr<XPU, IdType, DType>(
op, bcast, graph->GetCSRMatrix(0),
ufeat, efeat, out);
} else if (format == SparseFormat::kCOO) {
SDDMMCoo<XPU, IdType, DType>(
op, bcast, graph->GetCOOMatrix(0),
ufeat, efeat, out);
} else {
LOG(FATAL) << "SDDMM only supports CSR and COO foramts";
}
});
});
});
}
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
const std::string reduce_op = args[2];
NDArray U = args[3];
NDArray E = args[4];
NDArray V = args[5];
NDArray ArgU = args[6];
NDArray ArgE = args[7];
CheckCtx(graph->Context(), {U, E, V, ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
CheckContiguous({U, E, V, ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
CHECK_EQ(graph->NumEdgeTypes(), 1);
auto pair = graph->meta_graph()->FindEdge(0); // only one etype in the graph.
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
CheckShape(
{graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)},
{0, 1, 2, 2, 2},
{U, E, V, ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE}, SparseFormat::kAny);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
NDArray U = args[2];
NDArray V = args[3];
NDArray E = args[4];
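// Note the argument roles: U and V are the source/destination node features
// (lhs/rhs operands) and E receives the per-edge output.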
CheckCtx(graph->Context(), {U, V, E}, {"U_data", "V_data", "E_data"});
CheckContiguous({U, V, E}, {"U_data", "V_data", "E_data"});
CHECK_EQ(graph->NumEdgeTypes(), 1);
auto pair = graph->meta_graph()->FindEdge(0); // only one etype in the graph.
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
CheckShape(
{graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)},
{0, 1, 2},
{U, E, V},
{"U_data", "E_data", "V_data"});
SDDMM(op, graph.sptr(), U, V, E, SparseFormat::kAny);
});
} // namespace aten
} // namespace dgl