spmm_kernel.cu

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "compat.cuh"

#define THREADS 256
#define FULL_MASK 0xffffffff

enum ReductionType { SUM, MEAN, MIN, MAX };

const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};

#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    switch (reduce2REDUCE.at(reduce)) {                                        \
    case SUM: {                                                                \
      const ReductionType REDUCE = SUM;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MEAN: {                                                               \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MIN: {                                                                \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MAX: {                                                                \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()
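
// A minimal usage sketch of the macro above (`my_kernel` is illustrative
// only): the runtime `reduce` string is mapped to the compile-time constant
// `REDUCE`, which the lambda body can pass on as a template argument:
//
//   AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
//     my_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS>>>(/* ... */);
//   });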

template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline __host__ __device__ scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                int64_t *arg, int64_t new_arg) {
    if (REDUCE == SUM || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
                                               int64_t *arg_address,
                                               int64_t arg, int count) {
    if (REDUCE == SUM) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (scalar_t)max(count, 1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }
};
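
// Usage pattern of `Reducer` in the kernels below: start from `init()`, fold
// each candidate value in via `update()`, and emit the result once per output
// element with `write()`. For MIN/MAX the index of the winning non-zero is
// tracked through `arg`; for MEAN the accumulated sum is divided by the row's
// non-zero count at write time; rows without non-zeros produce a zero output.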

// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code:  https://github.com/owensgroup/merge-spmm
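//
// Parallelization of the kernel below: each warp handles one output row of
// one batch element, and `blockIdx.y` tiles the K columns of `mat` in chunks
// of 32, so lane `l` of a warp produces `out[row, blockIdx.y * 32 + l]`. The
// non-zeros of a row are loaded 32 at a time and broadcast across the warp
// with `__shfl_sync`, keeping the accesses to `col`, `value`, and `mat`
// coalesced.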
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
                            const scalar_t *value_data,
                            const scalar_t *mat_data, scalar_t *out_data,
                            int64_t *arg_out_data, int B, int M, int N, int K) {

  // `blockIdx.y` is ignored when computing `thread_idx`: blocks along the
  // y-axis process the same rows and only differ in which 32-column tile of
  // `mat` they cover.
  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  int row = thread_idx >> 5;            // thread_idx / 32
  int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
  int batch_idx = row / M;

  // Compute the column index of `mat` in which the thread is operating.
  int mat_col_idx = lane_idx + (blockIdx.y << 5);

  // Compute the output index (row-major order).
  int out_idx = row * K + mat_col_idx;

  // Helper arrays for warp communication.
  int mat_row, mat_rows[32];
  scalar_t val, vals[HAS_VAL ? 32 : 1];

  // Number of valid columns in this 32-column tile; lanes with
  // `lane_idx >= leftover` fall outside `mat` and must not aggregate or write.
  int leftover = K - (blockIdx.y << 5);

  if (batch_idx < B) {
    int row_start = __ldg(rowptr_data + (row % M));
    int row_end = __ldg(rowptr_data + (row % M) + 1);
    int col_idx = row_start + lane_idx;

    scalar_t result = Reducer<scalar_t, REDUCE>::init();
    int64_t arg;

    // Iterate over all `col` indices in parallel within a warp.
    for (int c = row_start; c < row_end; c += 32) {

      if (col_idx < row_end) {
        // Coalesced memory access into `col` and `val`.
        mat_row = __ldg(col_data + col_idx) * K;
        if (HAS_VAL)
          val = __ldg(value_data + col_idx);
      } else {
        mat_row = -1;
        if (HAS_VAL)
          val = (scalar_t)0;
      }
      col_idx += 32;

#pragma unroll
      for (int i = 0; i < 32; i++) {
        // Communication between all threads in a warp.
        mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
        if (HAS_VAL)
          vals[i] = __shfl_sync(FULL_MASK, val, i);
      }

#pragma unroll
      for (int i = 0; i < 32; i++) {
        if (lane_idx < leftover && mat_rows[i] != -1) {
          // Coalesced memory access into `mat`.
          val = __ldg(mat_data + batch_idx * N * K + mat_rows[i] + mat_col_idx);
          if (HAS_VAL)
            val = vals[i] * val;
          Reducer<scalar_t, REDUCE>::update(&result, val, &arg, c + i);
        }
      }
    }

    if (lane_idx < leftover) {
      // Coalesced write into `out`.
      Reducer<scalar_t, REDUCE>::write(out_data + out_idx, result,
                                       arg_out_data + out_idx, arg,
                                       row_end - row_start);
    }
  }
}

std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
          at::Tensor mat, std::string reduce) {

  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
  AT_ASSERTM(col.dim() == 1, "Input mismatch");
  if (value_opt.has_value())
    AT_ASSERTM(value_opt.value().dim() == 1, "Input mismatch");
  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");

  mat = mat.contiguous();

  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
  auto out = at::empty(sizes, mat.options());

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = at::full_like(out, col.numel(), rowptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto M = rowptr.numel() - 1;
  auto N = mat.size(-2);
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);
  // Launch one warp (32 threads) per output row of each batch element along
  // the x-axis, and one block per 32-column tile of `mat` along the y-axis.
  auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
    auto rowptr_data = rowptr.DATA_PTR<int64_t>();
    auto col_data = col.DATA_PTR<int64_t>();
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (value_opt.has_value()) {
        auto value_data = value_opt.value().DATA_PTR<scalar_t>();
        spmm_kernel<scalar_t, REDUCE, true><<<BLOCKS, THREADS, 0, stream>>>(
            rowptr_data, col_data, value_data, mat_data, out_data, arg_out_data,
            B, M, N, K);
      } else {
        spmm_kernel<scalar_t, REDUCE, false><<<BLOCKS, THREADS, 0, stream>>>(
            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, B,
            M, N, K);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}
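
// Illustrative call (a sketch of how the function above may be invoked): given
// a CSR sparse matrix (`rowptr`, `col`, optional `value`) and a dense `mat` on
// the same CUDA device,
//
//   at::Tensor out;
//   at::optional<at::Tensor> arg_out;
//   std::tie(out, arg_out) = spmm_cuda(rowptr, col, value, mat, "max");
//
// `arg_out` holds the index (into `col`) of the argmin/argmax non-zero per
// output entry; it is only populated for "min"/"max" and stays `at::nullopt`
// for "sum"/"add"/"mean".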

template <typename scalar_t, ReductionType REDUCE>
__global__ void
spmm_val_bw_kernel(const int64_t *index_data, const int64_t *rowptr_data,
                   const scalar_t *mat_data, const scalar_t *grad_data,
                   scalar_t *out_data, int B, int M, int N, int E, int K) {
  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  int index_idx = (thread_idx >> 5);    // thread_idx / 32
  int lane_idx = thread_idx & (32 - 1); // thread_idx % 32

  if (index_idx < E) {
    int row = __ldg(index_data + index_idx);
    int col = __ldg(index_data + E + index_idx);

    scalar_t val = (scalar_t)0;
    for (int b = 0; b < B; b++) {
      for (int k = lane_idx; k < K; k += 32) {
        val += mat_data[b * N * K + col * K + k] *
               grad_data[b * M * K + row * K + k];
      }
    }

#pragma unroll
    for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
      val += __shfl_down_sync(FULL_MASK, val, i);
    }

    if (lane_idx == 0) {
      if (REDUCE == MEAN) {
        int row_start = __ldg(rowptr_data + row);
        int row_end = __ldg(rowptr_data + row + 1);
        val /= (scalar_t)max(row_end - row_start, 1);
      }
      out_data[index_idx] = val;
    }
  }
}
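
// In the kernel above, each warp computes the gradient of one non-zero value:
// lanes stride over the K feature columns, accumulate
// mat[b, col, k] * grad[b, row, k] over all B batch elements, and the partial
// sums are combined with a warp-level `__shfl_down_sync` reduction before
// lane 0 writes the result (divided by the row's non-zero count for "mean").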

at::Tensor spmm_val_bw_cuda(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
                            at::Tensor grad, std::string reduce) {

  AT_ASSERTM(index.dim() == 2, "Input mismatch");
  AT_ASSERTM(index.size(0) == 2, "Input mismatch");
  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
  AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
  AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
                 reduce2REDUCE.at(reduce) == MEAN,
             "Reduce operation not supported");

  index = index.contiguous();
  mat = mat.contiguous();
  grad = grad.contiguous();

  auto M = grad.size(-2);
  auto N = mat.size(-2);
  auto E = index.size(1);
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);
  // One warp (32 threads) per non-zero.
  auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);

  auto out = at::empty(index.size(1), grad.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
    auto index_data = index.DATA_PTR<int64_t>();
    auto rowptr_data = rowptr.DATA_PTR<int64_t>();
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto grad_data = grad.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      spmm_val_bw_kernel<scalar_t, REDUCE>
          <<<BLOCKS, THREADS, 0, stream>>>(index_data, rowptr_data, mat_data,
                                           grad_data, out_data, B, M, N, E, K);
    });
  });

  return out;
}
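
// Illustrative call (a sketch of how the function above may be invoked):
// `index` holds the COO (row, col) pairs of the E non-zeros, `grad` is the
// gradient with respect to the SpMM output, and the result contains one
// gradient entry per non-zero value:
//
//   at::Tensor grad_value = spmm_val_bw_cuda(index, rowptr, mat, grad, "sum");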