/*!
 * Copyright (c) 2020 by Contributors
 * \file array/cuda/ge_spmm.cuh
 * \brief GE-SpMM CUDA kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_GE_SPMM_CUH_
#define DGL_ARRAY_CUDA_GE_SPMM_CUH_

#include "macro.cuh"
#include "atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using namespace cuda;

namespace aten {
namespace cuda {

/*! 
 * \brief CUDA kernel of GE-SpMM on Csr.
 * \note GE-SpMM: https://arxiv.org/pdf/2007.03179.pdf
 *       The grid dimension x and y are reordered for better performance.
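 * \param ufeat The dense matrix indexed by CSR column ids, of shape (num_cols, feat_len).
 * \param efeat The edge (nonzero) feature array, read only when BinaryOp::use_rhs is true.
 * \param out The output dense matrix indexed by CSR row ids, of shape (num_rows, feat_len).
 * \param indptr The CSR row pointer array of length num_rows + 1.
 * \param indices The CSR column index array.
 * \param num_rows The number of rows of the sparse matrix.
 * \param num_cols The number of columns of the sparse matrix.
 * \param feat_len The feature dimension (number of columns of ufeat and out).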
 */
template <typename Idx, typename DType,
          typename BinaryOp>
__global__ void GESpMMKernel(
    const DType* __restrict__ ufeat,
    const DType* __restrict__ efeat,
    DType* __restrict__ out,
    const Idx* __restrict__ indptr,
    const Idx* __restrict__ indices,
    const int64_t num_rows, const int64_t num_cols,
    const int64_t feat_len) {
  const Idx rid = blockIdx.x * blockDim.y + threadIdx.y;  // over vertices dimension
  const Idx fid = (blockIdx.y * 64) + threadIdx.x;        // over feature dimension
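  // Each block covers a 64-wide tile of the feature dimension with 32 threads:
  // every thread accumulates two outputs, at columns fid and fid + 32.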

  if (rid < num_rows && fid < feat_len) {
    const Idx low = __ldg(indptr + rid), high = __ldg(indptr + rid + 1);
    DType accum_0 = 0.,
          accum_1 = 0.;

    if (blockIdx.y != gridDim.y - 1) {  // fid + 32 < feat_len
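      // Non-boundary feature tiles: both fid and fid + 32 are valid columns.
      // Walk the row's nonzeros in 32-edge chunks; full chunks take the fully
      // unrolled path, the trailing partial chunk falls back to a plain loop.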
      for (Idx left = low; left < high; left += 32) {
        if (left + 32 <= high) {
#pragma unroll
          for (Idx i = 0; i < 32; ++i) {
            const Idx eid = left + i; 
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              accum_1 += ufeat[offset + 32];
            }
          }
        } else {
          for (Idx i = 0; left + i < high; ++i) {
            const Idx eid = left + i; 
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              accum_1 += ufeat[offset + 32];
            }
          }
        }

        out[feat_len * rid + fid] = accum_0;
        out[feat_len * rid + fid + 32] = accum_1;
      }
    } else {
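      // Last feature tile: fid + 32 may exceed feat_len (fid itself is already
      // guarded above), so clamp the unused load offset to 0 and guard the
      // second store below.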
      const Idx fid_0 = fid < feat_len ? fid : 0,
                fid_1 = fid + 32 < feat_len ? fid + 32 : 0;
      for (Idx left = low; left < high; left += 32) {
        if (left + 32 <= high) {
#pragma unroll
          for (Idx i = 0; i < 32; ++i) {
            const Idx eid = left + i; 
            const Idx cid = __ldg(indices + eid); 
            const Idx offset = feat_len * cid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset + fid_0, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + fid_1, efeat + eid);
            } else {
              accum_0 += ufeat[offset + fid_0];
              accum_1 += ufeat[offset + fid_1];
            }
          }
        } else {
          for (Idx i = 0; i + left < high; ++i) {
            const Idx eid = left + i; 
            const Idx cid = __ldg(indices + eid); 
            const Idx offset = feat_len * cid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset + fid_0, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + fid_1, efeat + eid);
            } else {
              accum_0 += ufeat[offset + fid_0];
              accum_1 += ufeat[offset + fid_1];
            }
          }
        }

        out[feat_len * rid + fid] = accum_0;
        if (fid + 32 < feat_len)
          out[feat_len * rid + fid + 32] = accum_1;
      }
    }
  }
}

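/*!
 * \brief Host-side launcher of the GE-SpMM kernel on a CSR matrix.
 * \param csr The sparse matrix in CSR format; indptr/indices must have dtype Idx.
 * \param ufeat The dense feature array indexed by CSR column ids.
 * \param efeat The edge feature array, used only when BinaryOp::use_rhs is true.
 * \param out The output feature array, of shape (csr.num_rows, feat_len).
 * \param feat_len The feature dimension.
 */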
template <typename Idx, typename DType,
          typename BinaryOp>
void GESpMMCsr(
    const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat,
    NDArray out, int64_t feat_len) {
  const Idx *indptr = csr.indptr.Ptr<Idx>();
  const Idx *indices = csr.indices.Ptr<Idx>();
  const DType *ufeat_data = ufeat.Ptr<DType>();
  const DType *efeat_data = efeat.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  
  const int ntx = 32;  // threads per block along the feature dimension
  const int nty = 32;  // rows (vertices) handled per block
  const int nby = (feat_len + (ntx * 2) - 1) / (ntx * 2);  // one block covers a 64-wide feature tile
  const int nbx = (csr.num_rows + nty - 1) / nty;          // row blocks
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const int sh_mem_size = 0;

  CUDA_KERNEL_CALL((GESpMMKernel<Idx, DType, BinaryOp>),
      nblks, nthrs, sh_mem_size, stream,
      ufeat_data, efeat_data, out_data,
      indptr, indices,
      csr.num_rows, csr.num_cols,
      feat_len);
}
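
/*
 * Usage sketch (illustrative only, not part of this header): dispatching
 * GESpMMCsr for a copy-lhs/sum SpMM, i.e. out[i, :] = sum of ufeat[j, :] over
 * the column ids j stored in row i of the CSR. The functor name
 * binary::CopyLhs<float> is an assumption about the companion functor header;
 * substitute whatever BinaryOp the call site actually uses.
 *
 *   const int64_t feat_len = ufeat->shape[1];
 *   // ufeat: (csr.num_cols, feat_len), out: (csr.num_rows, feat_len);
 *   // efeat may be an empty array when BinaryOp::use_rhs is false.
 *   GESpMMCsr<int32_t, float, binary::CopyLhs<float> >(
 *       csr, ufeat, efeat, out, feat_len);
 */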

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_GE_SPMM_CUH_