/*!
 * Copyright (c) 2020 by Contributors
 * \file array/cuda/ge_spmm.cuh
 * \brief GE-SpMM CUDA kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_GE_SPMM_CUH_
#define DGL_ARRAY_CUDA_GE_SPMM_CUH_

#include "macro.cuh"
#include "atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using namespace cuda;

namespace aten {
namespace cuda {

/*!
 * \brief CUDA kernel of GE-SpMM on CSR.
 * \note GE-SpMM: https://arxiv.org/pdf/2007.03179.pdf
 *       For each row rid and feature fid, the kernel computes a sum
 *       reduction over the row's edges:
 *         out[rid, fid] = sum_e BinaryOp(ufeat[indices[e], fid], efeat[e])
 *       Each thread accumulates two feature columns (fid and fid + 32).
 *       The grid dimensions x and y are reordered (rows on x, features
 *       on y) for better performance.
 */
template <typename Idx, typename DType,
          typename BinaryOp>
__global__ void GESpMMKernel(
    const DType* __restrict__ ufeat,
    const DType* __restrict__ efeat,
    DType* __restrict__ out,
    const Idx* __restrict__ indptr,
    const Idx* __restrict__ indices,
    const int64_t num_rows, const int64_t num_cols,
    const int64_t feat_len) {
  const Idx rid = blockIdx.x * blockDim.y + threadIdx.y;  // row (vertex) index
  const Idx fid = (blockIdx.y * 64) + threadIdx.x;        // feature index; each block covers 64 = 2 * blockDim.x features

  if (rid < num_rows && fid < feat_len) {
    // Edge range [low, high) of row rid in the CSR index arrays.
    const Idx low = __ldg(indptr + rid), high = __ldg(indptr + rid + 1);
    DType accum_0 = 0.,  // partial sum for feature column fid
          accum_1 = 0.;  // partial sum for feature column fid + 32

    if (blockIdx.y != gridDim.y - 1) {
      // All feature blocks but the last: both columns fid and fid + 32
      // are in-bounds, so no per-column guard is needed.
      for (Idx left = low; left < high; left += 32) {
        if (left + 32 <= high) {
          // Full tile of 32 edges: the inner loop can be fully unrolled.
#pragma unroll
          for (Idx i = 0; i < 32; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              accum_1 += ufeat[offset + 32];
            }
          }
        } else {
          // Partial tail tile with fewer than 32 edges.
          for (Idx i = 0; left + i < high; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              accum_1 += ufeat[offset + 32];
            }
          }
        }
      }

      // Write back once after the whole row has been reduced; this also
      // zero-fills rows that have no incoming edges.
      out[feat_len * rid + fid] = accum_0;
      out[feat_len * rid + fid + 32] = accum_1;
    } else {
      // Last feature block: column fid + 32 may fall outside the feature
      // dimension, so guard its loads as well as its store.
      const bool right_inbound = fid + 32 < feat_len;
      for (Idx left = low; left < high; left += 32) {
        if (left + 32 <= high) {
#pragma unroll
          for (Idx i = 0; i < 32; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              if (right_inbound)
                accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              if (right_inbound)
                accum_1 += ufeat[offset + 32];
            }
          }
        } else {
          for (Idx i = 0; left + i < high; ++i) {
            const Idx eid = left + i;
            const Idx cid = __ldg(indices + eid);
            const Idx offset = feat_len * cid + fid;
            if (BinaryOp::use_rhs) {
              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);
              if (right_inbound)
                accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);
            } else {
              accum_0 += ufeat[offset];
              if (right_inbound)
                accum_1 += ufeat[offset + 32];
            }
          }
        }
      }

      out[feat_len * rid + fid] = accum_0;
      if (right_inbound)
        out[feat_len * rid + fid + 32] = accum_1;
    }
  }
}
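
/*!
 * \brief Shape of a BinaryOp functor accepted by GESpMMKernel.
 * \note Illustrative sketch only; the functor below is a hypothetical
 *       example, not one of DGL's actual binary-op functors. A BinaryOp
 *       must expose a static use_rhs flag (whether edge features are
 *       consumed) and a device-side Call on element pointers:
 *
 *       template <typename DType>
 *       struct ExampleMul {  // hypothetical name, for illustration
 *         static constexpr bool use_rhs = true;   // reads efeat
 *         static __device__ __forceinline__ DType Call(
 *             const DType* lhs, const DType* rhs) {
 *           return lhs[0] * rhs[0];
 *         }
 *       };
 */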

template <typename Idx, typename DType,
          typename BinaryOp>
void GESpMMCsr(
    const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat,
    NDArray out, int64_t feat_len) {
  const Idx *indptr = csr.indptr.Ptr<Idx>();
  const Idx *indices = csr.indices.Ptr<Idx>();
  const DType *ufeat_data = ufeat.Ptr<DType>();
  const DType *efeat_data = efeat.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();

  const int ntx = 32;  // threads along the feature dimension
  const int nty = 32;  // threads along the row dimension
  // Each thread handles two feature columns, so one block covers
  // ntx * 2 = 64 features; see the fid computation in the kernel.
  const int nby = (feat_len + (ntx * 2) - 1) / (ntx * 2);
  const int nbx = (csr.num_rows + nty - 1) / nty;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  const int sh_mem_size = 0;  // the kernel uses no shared memory

  CUDA_KERNEL_CALL((GESpMMKernel<Idx, DType, BinaryOp>),
      nblks, nthrs, sh_mem_size, thr_entry->stream,
      ufeat_data, efeat_data, out_data,
      indptr, indices,
      csr.num_rows, csr.num_cols,
      feat_len);
}
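
/*!
 * \brief Usage sketch for GESpMMCsr.
 * \note Illustrative only; the tensor setup and the functor name below
 *       are assumptions, not defined in this header:
 *
 *       // CSRMatrix csr = ...;   // graph with num_rows nodes, CSR format
 *       // NDArray ufeat = ...;   // node features, shape (num_cols, feat_len)
 *       // NDArray efeat = ...;   // edge features (read only if use_rhs)
 *       // NDArray out = ...;     // pre-allocated output, (num_rows, feat_len)
 *       // GESpMMCsr<int32_t, float, ExampleMul<float>>(
 *       //     csr, ufeat, efeat, out, feat_len);
 */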

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_GE_SPMM_CUH_