ge_spmm.cuh 4.89 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/*!
 * Copyright (c) 2020 by Contributors
 * \file array/cuda/ge_spmm.cuh
 * \brief GE-SpMM CUDA kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_GE_SPMM_CUH_
#define DGL_ARRAY_CUDA_GE_SPMM_CUH_

#include "macro.cuh"
#include "atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using namespace cuda;

namespace aten {
namespace cuda {

21
/*!
 * \brief CUDA kernel of GE-SpMM on Csr.
 * \note GE-SpMM: https://arxiv.org/pdf/2007.03179.pdf
 *       The grid dimension x and y are reordered for better performance.
 *
 * Thread mapping: (blockIdx.x, threadIdx.y) selects the output row `rid`,
 * (blockIdx.y, threadIdx.x) selects the feature column `fid`. Every thread
 * produces two outputs of its row, columns `fid` and `fid + 32`, and each
 * block in y advances by 64 columns, so the kernel expects blockDim.x == 32.
 * No shared memory is used.
 *
 * For every stored entry `eid` in [indptr[rid], indptr[rid + 1]) with column
 * index `cid = indices[eid]`, the thread adds either
 * `BinaryOp::Call(ufeat + feat_len * cid + column, efeat + eid)` (when
 * `BinaryOp::use_rhs` is true) or `ufeat[feat_len * cid + column]` to the
 * accumulator of that column. The accumulated values are stored to
 * `out[feat_len * rid + column]` once, after all entries have been visited;
 * a row without entries therefore gets 0 written.
 *
 * \tparam Idx integer type of indptr / indices.
 * \tparam DType element type of the feature and output buffers.
 * \tparam BinaryOp type providing `use_rhs` and a static `Call(lhs, rhs)`.
 * \param ufeat features read per column index, row stride `feat_len`.
 * \param efeat per-entry operand, only passed to BinaryOp::Call.
 * \param out output buffer, row stride `feat_len`.
 * \param indptr CSR row pointer array (read at rid and rid + 1).
 * \param indices CSR column index array.
 * \param num_rows number of rows to produce.
 * \param num_cols not read by this kernel.
 * \param feat_len number of feature columns per row.
 */
template <typename Idx, typename DType,
          typename BinaryOp>
__global__ void GESpMMKernel(
    const DType* __restrict__ ufeat,
    const DType* __restrict__ efeat,
    DType* __restrict__ out,
    const Idx* __restrict__ indptr,
    const Idx* __restrict__ indices,
    const int64_t num_rows, const int64_t num_cols,
    const int64_t feat_len) {
  const Idx rid = blockIdx.x * blockDim.y + threadIdx.y;  // over vertices dimension
  const Idx fid = (blockIdx.y * 64) + threadIdx.x;        // over feature dimension

  if (rid < num_rows && fid < feat_len) {
    const Idx low = __ldg(indptr + rid), high = __ldg(indptr + rid + 1);
    // Position of (rid, fid) in `out`; kept in 64 bits because feat_len * rid
    // may not fit in a 32-bit Idx.
    const int64_t out_pos = feat_len * rid + fid;
    DType accum_0 = 0.,
          accum_1 = 0.;

    if (blockIdx.y != gridDim.y - 1) {  // fid + 32 < feat_len
      // Entries are consumed in chunks of 32 so that full chunks run through
      // a fixed-trip-count loop that can be unrolled.
      for (Idx left = low; left < high; left += 32) {
        if (left + 32 <= high) {
#pragma unroll
          for (Idx i = 0; i < 32; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            // 64-bit offset: feat_len * cid can exceed the range of Idx.
            const int64_t offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              accum_1 += ufeat[offset + 32];
            }
          }
        } else {
          // Remainder chunk with fewer than 32 entries.
          for (Idx i = 0; left + i < high; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const int64_t offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              accum_1 += ufeat[offset + 32];
            }
          }
        }
      }

      // Single store per output, also reached when the row has no entries.
      out[out_pos] = accum_0;
      out[out_pos + 32] = accum_1;
    } else {
      // Last block along the feature dimension: column fid + 32 may not exist.
      // In that case the second load is redirected to column 0 so it stays in
      // bounds, and its accumulated value is never stored.
      const bool has_second = fid + 32 < feat_len;
      const Idx fid_0 = fid,  // fid < feat_len is guaranteed by the guard above
                fid_1 = has_second ? fid + 32 : 0;
      for (Idx left = low; left < high; left += 32) {
        if (left + 32 <= high) {
#pragma unroll
          for (Idx i = 0; i < 32; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const int64_t offset = feat_len * cid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset + fid_0, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + fid_1, efeat + eid);
            } else {
              accum_0 += ufeat[offset + fid_0];
              accum_1 += ufeat[offset + fid_1];
            }
          }
        } else {
          for (Idx i = 0; left + i < high; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const int64_t offset = feat_len * cid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset + fid_0, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + fid_1, efeat + eid);
            } else {
              accum_0 += ufeat[offset + fid_0];
              accum_1 += ufeat[offset + fid_1];
            }
          }
        }
      }

      out[out_pos] = accum_0;
      if (has_second)
        out[out_pos + 32] = accum_1;
    }
  }
}

/*!
 * \brief Launch GESpMMKernel for a CSR matrix on the current CUDA stream.
 *
 * The launch uses 32 x 32 thread blocks: threadIdx.x spans feature columns
 * and threadIdx.y spans rows. Each thread handles two feature columns, so a
 * block covers 64 columns and 32 rows; these sizes must stay in sync with the
 * constants 32 / 64 hard-coded in GESpMMKernel.
 *
 * \tparam Idx integer type of the CSR index arrays.
 * \tparam DType element type of ufeat, efeat and out.
 * \tparam BinaryOp operator type forwarded to GESpMMKernel.
 * \param csr the CSR matrix; indptr, indices, num_rows and num_cols are used.
 * \param ufeat feature array whose data pointer is passed as `ufeat`.
 * \param efeat array whose data pointer is passed as `efeat`.
 * \param out output array whose data pointer is passed as `out`.
 * \param feat_len number of feature columns per row.
 */
template <typename Idx, typename DType,
          typename BinaryOp>
void GESpMMCsr(
    const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat,
    NDArray out, int64_t feat_len) {
  // Nothing to compute for an empty problem. Returning here also avoids
  // requesting a grid with a zero dimension, which CUDA rejects as an invalid
  // launch configuration.
  if (csr.num_rows <= 0 || feat_len <= 0) return;

  const Idx *indptr = csr.indptr.Ptr<Idx>();
  const Idx *indices = csr.indices.Ptr<Idx>();
  const DType *ufeat_data = ufeat.Ptr<DType>();
  const DType *efeat_data = efeat.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int ntx = 32;  // threads along the feature dimension
  const int nty = 32;  // threads (rows) along the vertex dimension
  // Each thread covers two feature columns, hence ntx * 2 columns per block.
  const int nby = (feat_len + (ntx * 2) - 1) / (ntx * 2);
  const int nbx = (csr.num_rows + nty - 1) / nty;
  // Rows are mapped to grid.x and features to grid.y.
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const int sh_mem_size = 0;

  CUDA_KERNEL_CALL((GESpMMKernel<Idx, DType, BinaryOp>),
      nblks, nthrs, sh_mem_size, stream,
      ufeat_data, efeat_data, out_data,
      indptr, indices,
      csr.num_rows, csr.num_cols,
      feat_len);
}

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

153
#endif  // DGL_ARRAY_CUDA_GE_SPMM_CUH_