Unverified Commit 071cba1f authored by Zihao Ye, committed by GitHub

[kernel] New SpMM & SDDMM kernel on CPU and CUDA (#1644)

* upd

* simplify

* sddmm dot cpu

* upd

* format

* upd

* compatible with MJ's PR

* lint

* upd

* upd

* upd

* python end

* upd

* upd

* lint

* lint

* upd

* upd

* upd

* upd

* upd

* lint

* fix mxnet

* upd

* lint

* use minjie's ptr

* macro

* upd

* reorg

* lint

* fix corner cases

* upd

* enrich cpu docs

* upd

* upd

* lint

* lint

* pylint

* sx review

* improve docstring

* python doc

* upd

* restructure

* lint

* upd test

* upd

* pylint

* fix corner cases and test
parent 90a103e7
/*!
 * Copyright (c) 2020 by Contributors
 * \file array/kernel_decl.h
 * \brief Sparse matrix format-specific operator declarations.
 */
#ifndef DGL_ARRAY_KERNEL_DECL_H_
#define DGL_ARRAY_KERNEL_DECL_H_

#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/base_heterograph.h>

#include <string>
#include <vector>

namespace dgl {
namespace aten {

/*!
 * \brief Generalized Sparse Matrix-Dense Matrix Multiplication on the CSR format.
 * \note ufeat and efeat hold the source-node and edge features; out_aux
 *       stores auxiliary outputs, e.g. the argument indices produced by
 *       the min/max reducers.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const aten::CSRMatrix& csr,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux);

/*!
 * \brief Generalized Sparse Matrix-Dense Matrix Multiplication on the COO format.
 */
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
             const BcastOff& bcast,
             const aten::COOMatrix& coo,
             NDArray ufeat,
             NDArray efeat,
             NDArray out,
             std::vector<NDArray> out_aux);

/*!
 * \brief Generalized Sampled Dense-Dense Matrix Multiplication on the CSR format.
 */
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const aten::CSRMatrix& csr,
              NDArray ufeat,
              NDArray efeat,
              NDArray out);

/*!
 * \brief Generalized Sampled Dense-Dense Matrix Multiplication on the COO format.
 */
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const aten::COOMatrix& coo,
              NDArray ufeat,
              NDArray efeat,
              NDArray out);

}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_KERNEL_DECL_H_
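For intuition, here is a minimal CPU sketch of what SpMMCsr computes, specialized to op = "mul" and reduce = "sum" with scalar features and no broadcasting. It assumes the reverse-CSR convention (indptr indexed by destination node, indices holding source node ids); all names below are illustrative, not DGL's actual kernel.

// Reference semantics of generalized SpMM (op = "mul", reduce = "sum"):
// for each destination node, combine each incoming edge's source feature
// with the edge feature via op, then reduce over the incoming edges.
#include <cstdio>
#include <vector>

void SpMMMulSumRef(const std::vector<int>& indptr,
                   const std::vector<int>& indices,
                   const std::vector<float>& ufeat,   // one scalar per source node
                   const std::vector<float>& efeat,   // one scalar per edge
                   std::vector<float>* out) {         // one scalar per destination node
  for (size_t dst = 0; dst + 1 < indptr.size(); ++dst) {
    float acc = 0.f;
    for (int e = indptr[dst]; e < indptr[dst + 1]; ++e)
      acc += ufeat[indices[e]] * efeat[e];            // op(u, e), then sum-reduce
    (*out)[dst] = acc;
  }
}

int main() {
  // Two destination nodes; dst 0 has edges from src {0, 1}, dst 1 from src {1}.
  std::vector<int> indptr = {0, 2, 3}, indices = {0, 1, 1};
  std::vector<float> u = {1.f, 2.f}, e = {10.f, 20.f, 30.f}, out(2);
  SpMMMulSumRef(indptr, indices, u, e, &out);
  std::printf("%g %g\n", out[0], out[1]);  // 1*10 + 2*20 = 50, 2*30 = 60
  return 0;
}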
/*!
 * Copyright (c) 2020 by Contributors
 * \file kernel/bcast.h
 * \brief Broadcast related function implementations.
 */
#include <dgl/bcast.h>
#include <dmlc/logging.h>
#include <algorithm>

namespace dgl {
namespace {

/*!
 * \brief Determine whether broadcasting is needed, given the operator
 *        type, the lhs array and the rhs array.
 */
bool UseBcast(const std::string& op, NDArray lhs, NDArray rhs) {
  if (op == "copy_u" || op == "copy_e")
    return false;  // broadcasting is not required for copy_u/copy_e
  if (lhs->ndim != rhs->ndim)
    return true;
  for (int i = 1; i < lhs->ndim; ++i) {
    if (lhs->shape[i] != rhs->shape[i])
      return true;
  }
  return false;
}

}  // namespace

/*!
 * \brief Compute broadcast and auxiliary information given the operator
 *        and operands for kernel computation.
 * \note Expects lhs and rhs to have ndim >= 2 and shapes valid for the
 *       op computation.
 */
BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs) {
  BcastOff rst;
  rst.lhs_len = 1;
  rst.rhs_len = 1;
  for (int i = 1; i < lhs->ndim; ++i)
    rst.lhs_len *= lhs->shape[i];
  for (int i = 1; i < rhs->ndim; ++i)
    rst.rhs_len *= rhs->shape[i];
  rst.use_bcast = UseBcast(op, lhs, rhs);
  rst.reduce_size = 1;  // defaults to 1, except when op == "dot".
  if (rst.use_bcast) {
    const int max_ndim = std::max(lhs->ndim, rhs->ndim) - 1;
    int out_len = 1, j = 0;
    if (op == "dot") {
      rst.reduce_size = lhs->shape[lhs->ndim - 1];  // set reduce_size for dot.
      ++j;  // do not consider the reduce axis when computing lhs_offset and rhs_offset.
    }
    int stride_l = 1, stride_r = 1;
    rst.lhs_offset.push_back(0);  // lhs_offset[0] is always 0
    rst.rhs_offset.push_back(0);  // rhs_offset[0] is always 0
    for (; j < max_ndim; ++j) {  // iterate over the axes from back to front.
      // dl is the size of the lhs array on the current axis, likewise for dr.
      const int dl = (lhs->ndim - 1 - j < 1) ? 1 : lhs->shape[lhs->ndim - 1 - j];
      const int dr = (rhs->ndim - 1 - j < 1) ? 1 : rhs->shape[rhs->ndim - 1 - j];
      for (int i = 1; i < std::max(dl, dr); ++i) {
        for (int k = 0; k < out_len; ++k) {
          /* Explanation:
           * if the current dimension is not a broadcast dimension for the lhs array:
           *   lhs_offset[i * out_len + k] = lhs_offset[k] + i * stride_l
           * else:
           *   lhs_offset[i * out_len + k] = lhs_offset[k]
           * likewise for rhs_offset.
           */
          rst.lhs_offset.push_back(rst.lhs_offset[k] + i * (i < dl) * stride_l);
          rst.rhs_offset.push_back(rst.rhs_offset[k] + i * (i < dr) * stride_r);
        }
      }
      out_len *= std::max(dl, dr);
      stride_l *= dl;
      stride_r *= dr;
    }
    rst.out_len = out_len;
  } else {
    rst.out_len = (op == "copy_e") ? rst.rhs_len : rst.lhs_len;
    if (op == "dot") {
      rst.reduce_size = lhs->shape[lhs->ndim - 1];  // set reduce_size for dot.
      rst.out_len /= rst.reduce_size;  // out_len is divided by reduce_size for dot.
    }
  }
#ifdef DEBUG
  LOG(INFO) << "lhs_len: " << rst.lhs_len << " "
            << "rhs_len: " << rst.rhs_len << " "
            << "out_len: " << rst.out_len << " "
            << "reduce_size: " << rst.reduce_size;
#endif
  return rst;
}

}  // namespace dgl
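To make the offset tables concrete, below is a stripped-down, self-contained mimic of the broadcasting loop above (no NDArray, no "dot" special case), run on feature shapes (3, 1) and (1, 3). The broadcast output feature shape is (3, 3), and element p of the flattened output reads lhs[lhs_offset[p]] and rhs[rhs_offset[p]].

// Worked example of CalcBcastOff's offset computation, re-implemented on
// plain vectors for illustration only.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Feature shapes only; the leading node/edge dimension is excluded,
  // mirroring how CalcBcastOff skips shape[0].
  const std::vector<int> lshape = {3, 1}, rshape = {1, 3};
  std::vector<int> lhs_off = {0}, rhs_off = {0};
  int out_len = 1, stride_l = 1, stride_r = 1;
  const int ndim = static_cast<int>(std::max(lshape.size(), rshape.size()));
  for (int j = 0; j < ndim; ++j) {  // axes from back to front
    const int dl = j < static_cast<int>(lshape.size()) ? lshape[lshape.size() - 1 - j] : 1;
    const int dr = j < static_cast<int>(rshape.size()) ? rshape[rshape.size() - 1 - j] : 1;
    for (int i = 1; i < std::max(dl, dr); ++i) {
      for (int k = 0; k < out_len; ++k) {
        lhs_off.push_back(lhs_off[k] + i * (i < dl) * stride_l);
        rhs_off.push_back(rhs_off[k] + i * (i < dr) * stride_r);
      }
    }
    out_len *= std::max(dl, dr);
    stride_l *= dl;
    stride_r *= dr;
  }
  // Prints lhs_offset = 0 0 0 1 1 1 2 2 2 and rhs_offset = 0 1 2 0 1 2 0 1 2,
  // i.e. out[a * 3 + b] = op(lhs[a], rhs[b]).
  for (int p = 0; p < out_len; ++p)
    std::printf("out[%d] <- lhs[%d] op rhs[%d]\n", p, lhs_off[p], rhs_off[p]);
  return 0;
}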
...
@@ -30,6 +30,44 @@ inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) {
   return ret;
 }
 
+/*!
+ * \brief Find a number of blocks that is no larger than nblks and the
+ *        max_nblks limit of the given axis ('x', 'y' or 'z').
+ */
+template <char axis>
+inline int FindNumBlocks(int nblks, int max_nblks = -1) {
+  int default_max_nblks = -1;
+  switch (axis) {
+    case 'x':
+      default_max_nblks = CUDA_MAX_NUM_BLOCKS_X;
+      break;
+    case 'y':
+      default_max_nblks = CUDA_MAX_NUM_BLOCKS_Y;
+      break;
+    case 'z':
+      default_max_nblks = CUDA_MAX_NUM_BLOCKS_Z;
+      break;
+    default:
+      LOG(FATAL) << "Axis " << axis << " not recognized";
+      break;
+  }
+  if (max_nblks == -1)
+    max_nblks = default_max_nblks;
+  CHECK_NE(nblks, 0);
+  if (nblks < max_nblks)
+    return nblks;
+  return max_nblks;
+}
+
+/*! \brief Read addr through the read-only data cache where __ldg is available (sm_35+). */
+template <typename T>
+__device__ __forceinline__ T _ldg(T* addr) {
+#if __CUDA_ARCH__ >= 350
+  return __ldg(addr);
+#else
+  return *addr;
+#endif
+}
+
 }  // namespace cuda
 }  // namespace dgl
...
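As a usage note, FindNumThreads and FindNumBlocks are typically combined to size a launch grid. Below is a self-contained sketch with the two helpers re-implemented locally; the constants and problem sizes are illustrative stand-ins, not DGL's actual values.

#include <algorithm>
#include <cstdio>

constexpr int kMaxThreads = 1024;   // stand-in for CUDA_MAX_NUM_THREADS
constexpr int kMaxBlocksX = 65535;  // stand-in for CUDA_MAX_NUM_BLOCKS_X

// Largest power of two <= dim, capped at max_nthrs (mirrors FindNumThreads).
int FindNumThreads(int dim, int max_nthrs = kMaxThreads) {
  int ret = max_nthrs;
  while (ret > dim) ret >>= 1;
  return std::max(ret, 1);
}

// Clamp the requested block count to the per-axis hardware limit.
int FindNumBlocks(int nblks, int max_nblks = kMaxBlocksX) {
  return std::min(nblks, max_nblks);
}

int main() {
  const int num_edges = 3000, feat_len = 50;  // hypothetical problem sizes
  const int ntx = FindNumThreads(feat_len);   // threads cover the feature axis
  // One logical thread per (edge, feature) element, clamped on the x axis;
  // a real kernel would stride-loop when nbx hits the clamp.
  const int nbx = FindNumBlocks((num_edges * feat_len + ntx - 1) / ntx);
  std::printf("launch <<<%d, %d>>>\n", nbx, ntx);  // e.g. MyKernel<<<nbx, ntx>>>(...)
  return 0;
}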
...
@@ -130,7 +130,7 @@ struct SwitchSrcDst<SelectDst> {
 // common binary functors
 template <typename DType>
 struct BinaryAdd {
-  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+  static DGLDEVICE DGLINLINE DType Call(const DType *lhs, const DType *rhs, int64_t len) {
     return lhs[0] + rhs[0];
   }
   static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
@@ -143,7 +143,7 @@ struct BinaryAdd {
 template <typename DType>
 struct BinaryMul {
-  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+  static DGLDEVICE DGLINLINE DType Call(const DType *lhs, const DType *rhs, int64_t len) {
     return lhs[0] * rhs[0];
   }
   static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
@@ -156,7 +156,7 @@ struct BinaryMul {
 template <typename DType>
 struct BinarySub {
-  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+  static DGLDEVICE DGLINLINE DType Call(const DType *lhs, const DType *rhs, int64_t len) {
     return lhs[0] - rhs[0];
   }
   static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
@@ -169,7 +169,7 @@ struct BinarySub {
 template <typename DType>
 struct BinaryDiv {
-  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+  static DGLDEVICE DGLINLINE DType Call(const DType *lhs, const DType *rhs, int64_t len) {
     return lhs[0] / rhs[0];
   }
   static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
@@ -182,7 +182,7 @@ struct BinaryDiv {
 template <typename DType>
 struct BinaryUseLhs {
-  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+  static DGLDEVICE DGLINLINE DType Call(const DType *lhs, const DType *rhs, int64_t len) {
     return lhs[0];
   }
   static DGLDEVICE DGLINLINE DType BackwardLhs(DType lhs, DType rhs, DType out) {
@@ -195,7 +195,7 @@ struct BinaryUseLhs {
 template <typename DType>
 struct BinaryDot {
-  static DGLDEVICE DGLINLINE DType Call(DType *lhs, DType *rhs, int64_t len) {
+  static DGLDEVICE DGLINLINE DType Call(const DType *lhs, const DType *rhs, int64_t len) {
     DType out = 0;
     // simple vector dot vector
 #pragma unroll
...
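The kernels template over these functors; only BinaryDot inspects len (it reduces over the trailing axis, whose length is bcast.reduce_size), while the element-wise functors read just lhs[0] and rhs[0]. A minimal host-side sketch of the pattern, with DGLDEVICE/DGLINLINE dropped so it compiles as plain C++:

#include <cstdint>
#include <cstdio>

template <typename DType>
struct BinaryDot {
  static DType Call(const DType* lhs, const DType* rhs, int64_t len) {
    DType out = 0;
    for (int64_t i = 0; i < len; ++i)
      out += lhs[i] * rhs[i];  // simple vector-vector dot product
    return out;
  }
};

int main() {
  const float u[3] = {1.f, 2.f, 3.f}, v[3] = {4.f, 5.f, 6.f};
  std::printf("%g\n", BinaryDot<float>::Call(u, v, 3));  // 1*4 + 2*5 + 3*6 = 32
  return 0;
}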
...
@@ -22,6 +22,12 @@ int FindNumThreads(int dim, int max_nthrs) {
   return ret;
 }
 
+int FindNumBlocks(int nblks, int max_nblks) {
+  if (nblks < max_nblks)
+    return nblks;
+  return max_nblks;
+}
+
 int64_t ComputeXLength(runtime::NDArray feat_array) {
   int64_t ret = 1;
   for (int i = 1; i < feat_array->ndim; ++i) {
...
...
@@ -9,7 +9,7 @@
 #include <minigun/csr.h>
 #include <dlpack/dlpack.h>
 #include <dgl/runtime/ndarray.h>
+#include <dgl/array.h>
 #include <cstdlib>
 #include <vector>
...
@@ -23,6 +23,11 @@ namespace utils {
  */
 int FindNumThreads(int dim, int max_nthrs);
 
+/*
+ * !\brief Find a number of blocks that is no larger than nblks and max_nblks.
+ */
+int FindNumBlocks(int nblks, int max_nblks);
+
 /*
  * !\brief Compute the total number of feature elements.
  */
...
...
@@ -95,7 +95,7 @@ def div(a, b):
     """Compute a / b"""
     pass
 
-def sum(x, dim):
+def sum(x, dim, keepdims=False):
     """Computes the sum of array elements over given axes"""
     pass
...
...
@@ -71,8 +71,8 @@ def mul(a, b):
 def div(a, b):
     return a / b
 
-def sum(x, dim):
-    return x.sum(dim)
+def sum(x, dim, keepdims=False):
+    return x.sum(dim, keepdims=keepdims)
 
 def max(x, dim):
     return x.max(dim)
...
...
@@ -70,8 +70,8 @@ def mul(a, b):
 def div(a, b):
     return a / b
 
-def sum(x, dim):
-    return x.sum(dim)
+def sum(x, dim, keepdims=False):
+    return x.sum(dim, keepdim=keepdims)
 
 def max(x, dim):
     return x.max(dim)[0]
...
...
@@ -158,8 +158,8 @@ def div(a, b):
     return a / b
 
-def sum(x, dim):
-    return tf.reduce_sum(x, axis=dim)
+def sum(x, dim, keepdims=False):
+    return tf.reduce_sum(x, axis=dim, keepdims=keepdims)
 
 def max(x, dim):
...
import dgl
import pytest
import networkx as nx
import backend as F
import numpy as np

np.random.seed(42)
dgl.random.seed(42)

def _unsqueeze_if_scalar(x):  # used in UDFs to unsqueeze the feature if it is a scalar
    return x if F.ndim(x) > 1 else F.unsqueeze(x, -1)

def _rand_operand_1(shp):
    return F.tensor(np.random.rand(*shp))

def _rand_operand_2(shp):  # for the division op, the divisor should be at least 1
    return F.tensor(np.random.rand(*shp) + 1)

udf_msg = {
    'add': lambda edges: {'m': edges.src['x'] + edges.data['w']},
    'sub': lambda edges: {'m': edges.src['x'] - edges.data['w']},
    'mul': lambda edges: {'m': edges.src['x'] * edges.data['w']},
    'div': lambda edges: {'m': edges.src['x'] / edges.data['w']},
    'copy_u': lambda edges: {'m': edges.src['x']},
    'copy_e': lambda edges: {'m': edges.data['w']}
}

udf_apply_edges = {
    'add': lambda edges: {'m': edges.src['x'] + edges.dst['y']},
    'sub': lambda edges: {'m': edges.src['x'] - edges.dst['y']},
    'mul': lambda edges: {'m': edges.src['x'] * edges.dst['y']},
    'div': lambda edges: {'m': edges.src['x'] / edges.dst['y']},
    'dot': lambda edges: {'m': F.sum(edges.src['x'] * edges.dst['y'], -1, keepdims=True)},
    'copy_u': lambda edges: {'m': edges.src['x']},
}

udf_reduce = {
    'sum': lambda nodes: {'v': F.sum(nodes.mailbox['m'], 1)},
    'min': lambda nodes: {'v': F.min(nodes.mailbox['m'], 1)},
    'max': lambda nodes: {'v': F.max(nodes.mailbox['m'], 1)}
}

graphs = [
    dgl.rand_graph(30, 0),
    dgl.rand_graph(100, 30),
    dgl.rand_graph(100, 3000),
    dgl.rand_bipartite(80, 160, 3000)
]

spmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((5, 3, 1, 7), (1, 3, 7, 1)),
    ((1, 3, 1), (4, 1, 3)),
    ((3, 3), (1, 3)),
    ((), (3,)),
    ((3,), ()),
    ((), ())
]

sddmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((5, 3, 1, 7), (1, 3, 7, 7)),
    ((1, 3, 3), (4, 1, 3)),
    ((3, 3), (1, 3)),
    ((3,), (3,)),
    ((), ())
]

@pytest.mark.parametrize('g', graphs)
@pytest.mark.parametrize('shp', spmm_shapes)
@pytest.mark.parametrize('msg', ['add', 'sub', 'mul', 'div', 'copy_u', 'copy_e'])
@pytest.mark.parametrize('reducer', ['sum', 'min', 'max'])
def test_spmm(g, shp, msg, reducer):
    print(g)
    u = _rand_operand_1((g.number_of_src_nodes(),) + shp[0])
    e = _rand_operand_2((g.number_of_edges(),) + shp[1])
    print('u shape: {}, e shape: {}'.format(F.shape(u), F.shape(e)))
    g.srcdata['x'] = _unsqueeze_if_scalar(u)
    g.edata['w'] = _unsqueeze_if_scalar(e)
    print('SpMM(message func: {}, reduce func: {})'.format(msg, reducer))
    v = dgl.gspmm(g, msg, reducer, u, e)[0]
    # compare only nodes with non-zero in-degree: update_all does not
    # produce output rows for isolated nodes.
    non_degree_indices = F.tensor(
        np.nonzero(F.asnumpy(g.in_degrees()) != 0)[0])
    v = F.gather_row(v, non_degree_indices)
    g.update_all(udf_msg[msg], udf_reduce[reducer])
    if 'v' in g.dstdata:
        v1 = F.gather_row(g.dstdata['v'], non_degree_indices)
        assert F.allclose(v, v1, rtol=1e-3, atol=1e-3)
    print('passed')
    g.srcdata.pop('x')
    g.edata.pop('w')
    if 'v' in g.dstdata:
        g.dstdata.pop('v')

@pytest.mark.parametrize('g', graphs)
@pytest.mark.parametrize('shp', sddmm_shapes)
@pytest.mark.parametrize('msg', ['add', 'sub', 'mul', 'div', 'dot', 'copy_u'])
def test_sddmm(g, shp, msg):
    if dgl.backend.backend_name == 'mxnet' and g.number_of_edges() == 0:
        pytest.skip()  # mxnet does not support zero-shape tensors
    print(g)
    u = _rand_operand_1((g.number_of_src_nodes(),) + shp[0])
    v = _rand_operand_2((g.number_of_dst_nodes(),) + shp[1])
    print('u shape: {}, v shape: {}'.format(F.shape(u), F.shape(v)))
    g.srcdata['x'] = _unsqueeze_if_scalar(u)
    g.dstdata['y'] = _unsqueeze_if_scalar(v)
    print('SDDMM(message func: {})'.format(msg))
    e = dgl.gsddmm(g, msg, u, v)
    g.apply_edges(udf_apply_edges[msg])
    if 'm' in g.edata:
        e1 = g.edata['m']
        assert F.allclose(e, e1, rtol=1e-3, atol=1e-3)
    print('passed')
    g.srcdata.pop('x')
    g.dstdata.pop('y')
    if 'm' in g.edata:
        g.edata.pop('m')