/*!
 * Copyright (c) 2020 by Contributors
 * @file array/cuda/ge_spmm.cuh
 * @brief GE-SpMM CUDA kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_GE_SPMM_CUH_
#define DGL_ARRAY_CUDA_GE_SPMM_CUH_

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "atomic.cuh"
#include "macro.cuh"

namespace dgl {

using namespace cuda;

namespace aten {
namespace cuda {

/*!
 * @brief CUDA kernel of GE-SpMM on Csr.
 * @note GE-SpMM: https://arxiv.org/pdf/2007.03179.pdf
 *       The grid dimension x and y are reordered for better performance.
 *
 * Expected launch layout (see GESpMMCsr below): blockDim = (32, 32);
 * each thread row (threadIdx.y) accumulates one destination vertex, and
 * each thread produces two output features (fid and fid + 32), so one
 * block covers 64 consecutive features along gridDim.y.
 * No shared memory is used.
 */
template <typename Idx, typename DType, typename BinaryOp>
__global__ void GESpMMKernel(
    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
    DType* __restrict__ out, const Idx* __restrict__ indptr,
    const Idx* __restrict__ indices, const int64_t num_rows,
    const int64_t num_cols, const int64_t feat_len) {
  const Idx rid =
      blockIdx.x * blockDim.y + threadIdx.y;        // over vertices dimension
  const Idx fid = (blockIdx.y * 64) + threadIdx.x;  // over feature dimension

  if (rid < num_rows && fid < feat_len) {
    // Edge range [low, high) of row rid; __ldg routes the loads through
    // the read-only data cache.
    const Idx low = __ldg(indptr + rid), high = __ldg(indptr + rid + 1);
    DType accum_0 = 0., accum_1 = 0.;

    if (blockIdx.y != gridDim.y - 1) {  // fid + 32 < feat_len
      // Process edges in chunks of 32 so full chunks can be unrolled.
      for (Idx left = low; left < high; left += 32) {
        if (left + 32 <= high) {
#pragma unroll
          for (Idx i = 0; i < 32; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              accum_1 += ufeat[offset + 32];
            }
          }
        } else {
          // Remainder chunk (< 32 edges) — trip count unknown, no unroll.
          for (Idx i = 0; left + i < high; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              accum_1 += ufeat[offset + 32];
            }
          }
        }
      }

      // Store once after all edges are accumulated. Keeping the stores
      // outside the chunk loop avoids redundant global writes of partial
      // sums per 32-edge chunk and guarantees zero-in-degree rows are
      // written (as zeros) instead of being left uninitialized.
      out[feat_len * rid + fid] = accum_0;
      out[feat_len * rid + fid + 32] = accum_1;
    } else {
      // Last feature block: fid + 32 may run past feat_len, so clamp the
      // load offsets to 0 (the corresponding store is guarded below).
      const Idx fid_0 = fid < feat_len ? fid : 0,
                fid_1 = fid + 32 < feat_len ? fid + 32 : 0;
      // Use Idx (not int) for the edge cursor so 64-bit edge offsets do
      // not overflow when Idx is int64_t.
      for (Idx left = low; left < high; left += 32) {
        if (left + 32 <= high) {
#pragma unroll
          for (Idx i = 0; i < 32; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset + fid_0, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + fid_1, efeat + eid);
            } else {
              accum_0 += ufeat[offset + fid_0];
              accum_1 += ufeat[offset + fid_1];
            }
          }
        } else {
          for (Idx i = 0; i + left < high; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset + fid_0, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + fid_1, efeat + eid);
            } else {
              accum_0 += ufeat[offset + fid_0];
              accum_1 += ufeat[offset + fid_1];
            }
          }
        }
      }

      // Stores hoisted out of the chunk loop (see note in the other
      // branch); the second store is guarded against the feature tail.
      out[feat_len * rid + fid] = accum_0;
      if (fid + 32 < feat_len) out[feat_len * rid + fid + 32] = accum_1;
    }
  }
}

/*!
 * @brief Host-side launcher for the GE-SpMM kernel on a CSR matrix.
 * @param csr The sparse matrix in CSR format.
 * @param ufeat Source-node feature array.
 * @param efeat Edge feature array (used only when BinaryOp::use_rhs).
 * @param out Output feature array, written by the kernel.
 * @param feat_len Length of the feature dimension.
 */
template <typename Idx, typename DType, typename BinaryOp>
void GESpMMCsr(
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    int64_t feat_len) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // Raw device pointers handed to the kernel.
  const Idx* indptr = csr.indptr.Ptr<Idx>();
  const Idx* indices = csr.indices.Ptr<Idx>();
  const DType* ufeat_data = ufeat.Ptr<DType>();
  const DType* efeat_data = efeat.Ptr<DType>();
  DType* out_data = out.Ptr<DType>();

  // Launch configuration: 32x32 threads per block. Each thread computes
  // two feature elements, hence the ceil-divide by ntx * 2 along the
  // feature dimension (gridDim.y); rows map to gridDim.x.
  const int ntx = 32;
  const int nty = 32;
  const int nby = (feat_len + (ntx * 2) - 1) / (ntx * 2);
  const int nbx = (csr.num_rows + nty - 1) / nty;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const int sh_mem_size = 0;  // the kernel uses no dynamic shared memory

  CUDA_KERNEL_CALL(
      (GESpMMKernel<Idx, DType, BinaryOp>), nblks, nthrs, sh_mem_size, stream,
      ufeat_data, efeat_data, out_data, indptr, indices, csr.num_rows,
      csr.num_cols, feat_len);
}

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_GE_SPMM_CUH_