Unverified commit 272cb9e2, authored by Zihao Ye and committed by GitHub

[kernel] Select GE-SpMM when feature size is large. (#2306)

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd
parent 5f198763
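The commit title states the new dispatch policy: for sum-reduce SpMM, the GE-SpMM kernel added below is used once the output feature width exceeds 64, provided the operator either ignores edge data (copy_lhs) or carries one scalar per edge; otherwise the existing cuSPARSE and generic paths remain. The following is a minimal standalone sketch of that policy, mirroring the SpMMCsr dispatch shown in the diff further down; the enum and function names here are hypothetical and used only for illustration, not DGL APIs.

// Hedged illustration of the kernel-selection rule introduced by this commit.
// Backend and PickSpMMBackend are made-up names, not part of DGL.
#include <cstdint>
#include <string>

enum class Backend { kGESpMM, kCuSparse, kGeneral };

// feat_len        : broadcast output feature length (bcast.out_len in the diff)
// is_scalar_efeat : efeat holds one scalar per nonzero (NumElements() == nnz)
// idtype_bytes    : sizeof(IdType); the cuSPARSE path requires 4 (int32 indices)
inline Backend PickSpMMBackend(const std::string& op, const std::string& reduce,
                               int64_t feat_len, bool is_scalar_efeat,
                               int idtype_bytes) {
  const bool use_efeat = (op != "copy_lhs");
  if (reduce == "sum") {
    if ((!use_efeat || is_scalar_efeat) && feat_len > 64)
      return Backend::kGESpMM;     // wide features: GE-SpMM
    if (idtype_bytes == 4 && op == "copy_lhs")
      return Backend::kCuSparse;   // plain SpMM via cuSPARSE
    if (idtype_bytes == 4 && op == "mul" && is_scalar_efeat)
      return Backend::kCuSparse;   // edge-weighted SpMM via cuSPARSE
  }
  return Backend::kGeneral;        // fall back to the generic DGL kernel
}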
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/ge_spmm.cuh
* \brief GE-SpMM CUDA kernel function header.
*/
#ifndef DGL_ARRAY_CUDA_GE_SPMM_CUH_
#define DGL_ARRAY_CUDA_GE_SPMM_CUH_
#include "macro.cuh"
#include "atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
namespace dgl {
using namespace cuda;
namespace aten {
namespace cuda {
/*!
* \brief CUDA kernel of GE-SpMM on Csr.
* \note GE-SpMM: https://arxiv.org/pdf/2007.03179.pdf
* The grid dimension x and y are reordered for better performance.
*/
template <typename Idx, typename DType,
          typename BinaryOp>
__global__ void GESpMMKernel(
    const DType* __restrict__ ufeat,
    const DType* __restrict__ efeat,
    DType* __restrict__ out,
    const Idx* __restrict__ indptr,
    const Idx* __restrict__ indices,
    const int64_t num_rows, const int64_t num_cols,
    const int64_t feat_len) {
  const Idx rid = blockIdx.x * blockDim.y + threadIdx.y;  // over vertices dimension
  const Idx fid = (blockIdx.y * 64) + threadIdx.x;        // over feature dimension

  if (rid < num_rows && fid < feat_len) {
    const Idx low = __ldg(indptr + rid), high = __ldg(indptr + rid + 1);
    DType accum_0 = 0.,
          accum_1 = 0.;

    if (blockIdx.y != gridDim.y - 1) {  // fid + 32 < feat_len is guaranteed here
      for (Idx left = low; left < high; left += 32) {
        if (left + 32 <= high) {
          #pragma unroll
          for (Idx i = 0; i < 32; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              accum_1 += ufeat[offset + 32];
            }
          }
        } else {
          for (Idx i = 0; left + i < high; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              accum_1 += ufeat[offset + 32];
            }
          }
        }
        out[feat_len * rid + fid] = accum_0;
        out[feat_len * rid + fid + 32] = accum_1;
      }
    } else {  // last feature tile: the second element may fall outside feat_len
      const bool right_inbound = fid + 32 < feat_len;
      for (Idx left = low; left < high; left += 32) {
        if (left + 32 <= high) {
          #pragma unroll
          for (Idx i = 0; i < 32; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              if (right_inbound)  // avoid reading past the last feature tile
                accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              if (right_inbound)
                accum_1 += ufeat[offset + 32];
            }
          }
        } else {
          for (Idx i = 0; i + left < high; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              if (right_inbound)
                accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              if (right_inbound)
                accum_1 += ufeat[offset + 32];
            }
          }
        }
        out[feat_len * rid + fid] = accum_0;
        if (right_inbound)
          out[feat_len * rid + fid + 32] = accum_1;
      }
    }
  }
}
template <typename Idx, typename DType,
          typename BinaryOp>
void GESpMMCsr(
    const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat,
    NDArray out, int64_t feat_len) {
  const Idx *indptr = csr.indptr.Ptr<Idx>();
  const Idx *indices = csr.indices.Ptr<Idx>();
  const DType *ufeat_data = ufeat.Ptr<DType>();
  const DType *efeat_data = efeat.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();

  const int ntx = 32;  // one warp along the feature dimension
  const int nty = 32;  // 32 destination rows per thread block
  const int nby = (feat_len + (ntx * 2) - 1) / (ntx * 2);  // each block covers 64 features
  const int nbx = (csr.num_rows + nty - 1) / nty;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const int sh_mem_size = 0;

  CUDA_KERNEL_CALL((GESpMMKernel<Idx, DType, BinaryOp>),
      nblks, nthrs, sh_mem_size, thr_entry->stream,
      ufeat_data, efeat_data, out_data,
      indptr, indices,
      csr.num_rows, csr.num_cols,
      feat_len);
}
} // namespace cuda
} // namespace aten
} // namespace dgl
#endif  // DGL_ARRAY_CUDA_GE_SPMM_CUH_
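For context on what the kernel above computes: each thread block is 32 by 32, blockDim.y walks 32 destination rows while blockDim.x covers a 64-wide feature tile with two elements per thread (fid and fid + 32), which appears to be what the "reordered" grid note refers to. Below is a plain CPU reference for the copy_lhs, sum-reduce case, useful only as a correctness sketch; the helper name GESpMMReference is hypothetical and not part of DGL.

// CPU reference for the sum-reduce, copy_lhs case handled by GESpMMKernel:
//   out[r, f] = sum over CSR entries e in row r of ufeat[indices[e], f].
// GESpMMReference is a made-up name used only for this sketch.
#include <cstdint>
#include <vector>

template <typename Idx, typename DType>
void GESpMMReference(const std::vector<Idx>& indptr,
                     const std::vector<Idx>& indices,
                     const std::vector<DType>& ufeat,   // num_cols x feat_len, row-major
                     std::vector<DType>* out,           // num_rows x feat_len, row-major
                     int64_t num_rows, int64_t feat_len) {
  for (int64_t r = 0; r < num_rows; ++r) {
    for (int64_t f = 0; f < feat_len; ++f) {
      DType accum = 0;
      for (Idx e = indptr[r]; e < indptr[r + 1]; ++e)
        accum += ufeat[static_cast<int64_t>(indices[e]) * feat_len + f];
      (*out)[r * feat_len + f] = accum;
    }
  }
}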
@@ -5,6 +5,7 @@
  */
 #include <dgl/array.h>
 #include "./spmm.cuh"
+#include "./ge_spmm.cuh"
 #include "./functor.cuh"
 #include "../../runtime/cuda/cuda_common.h"
@@ -238,8 +239,19 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
              NDArray efeat,
              NDArray out,
              std::vector<NDArray> out_aux) {
+  int64_t feat_len = bcast.out_len;
+  bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];
+  bool use_efeat = op != "copy_lhs";
   if (reduce == "sum") {
-    if (sizeof(IdType) == 4 && op == "copy_lhs") {
+    if ((!use_efeat || is_scalar_efeat) && feat_len > 64) {  // ge-spmm
+      if (use_efeat && !IsNullArray(csr.data))  // reorder edge data
+        efeat = IndexSelect(efeat, csr.data);
+      SWITCH_OP(op, Op, {
+        cuda::GESpMMCsr<IdType, DType, Op>(
+            csr, ufeat, efeat, out, feat_len);
+      });
+    } else if (sizeof(IdType) == 4 && op == "copy_lhs") {  // cusparse
       int64_t x_length = 1;
       for (int i = 1; i < ufeat->ndim; ++i)
         x_length *= ufeat->shape[i];
@@ -249,7 +261,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
          nullptr,
          static_cast<DType*>(out->data),
          x_length);
-    } else if (sizeof(IdType) == 4 && op == "mul" && efeat.NumElements() == csr.indices->shape[0]) {
+    } else if (sizeof(IdType) == 4 && op == "mul" && is_scalar_efeat) {  // cusparse
       int64_t x_length = 1;
       for (int i = 1; i < ufeat->ndim; ++i)
         x_length *= ufeat->shape[i];
@@ -261,7 +273,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
          static_cast<DType*>(efeat->data),
          static_cast<DType*>(out->data),
          x_length);
-    } else {
+    } else {  // general kernel
       SWITCH_OP(op, Op, {
         cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
             bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
...
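One detail in the dispatch hunk above is worth unpacking: before GESpMMCsr is called with edge data, the edge features are gathered through csr.data, because csr.data stores the original edge id held at each CSR position while the kernel indexes efeat by CSR position. The snippet below is a tiny standalone illustration of that gather under those assumptions; the function name is made up, and the real code uses DGL's IndexSelect.

// Illustration only: permute per-edge scalars into CSR order before GE-SpMM.
// GatherEdgeData is a hypothetical helper; DGL does this via IndexSelect(efeat, csr.data).
#include <cstddef>
#include <vector>

template <typename DType, typename Idx>
std::vector<DType> GatherEdgeData(const std::vector<DType>& efeat,     // indexed by original edge id
                                  const std::vector<Idx>& csr_data) {  // CSR position -> edge id
  std::vector<DType> reordered(csr_data.size());
  for (std::size_t pos = 0; pos < csr_data.size(); ++pos)
    reordered[pos] = efeat[csr_data[pos]];
  return reordered;
}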
@@ -150,7 +150,7 @@ __global__ void SpMMCsrKernel(
     const Idx* __restrict__ indptr,
     const Idx* __restrict__ indices,
     const Idx* __restrict__ edge_map,
-    int64_t num_rows, int64_t num_cols, int64_t nnz,
+    int64_t num_rows, int64_t num_cols,
     const int64_t* __restrict__ ubcast_off,
     const int64_t* __restrict__ ebcast_off,
     int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
@@ -306,7 +306,7 @@ void SpMMCsr(
         nblks, nthrs, 0, thr_entry->stream,
         ufeat_data, efeat_data, out_data, argu_data, arge_data,
         indptr, indices, edge_map,
-        csr.num_rows, csr.num_cols, efeat->shape[0],
+        csr.num_rows, csr.num_cols,
         ubcast_off, ebcast_off,
         lhs_len, rhs_len, len)
   });
...