#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <limits>
#include <map>

#include "compat.cuh"

#define THREADS 256
#define FULL_MASK 0xffffffff

enum ReductionType { SUM, MEAN, MIN, MAX };

const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};

#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    switch (reduce2REDUCE.at(reduce)) {                                        \
    case SUM: {                                                                \
      const ReductionType REDUCE = SUM;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MEAN: {                                                               \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MIN: {                                                                \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MAX: {                                                                \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()
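
// This mirrors ATen's AT_DISPATCH_* pattern: the runtime `reduce` string is
// mapped to a compile-time `REDUCE` constant which the lambda can forward as
// a template argument, e.g.:
//
//   AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
//     spmm_kernel<scalar_t, REDUCE, true><<<BLOCKS, THREADS, 0, stream>>>(...);
//   });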

template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline __host__ __device__ scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                int64_t *arg, int64_t new_arg) {
    if (REDUCE == SUM || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
                                               int64_t *arg_address,
                                               int64_t arg, int count) {
    if (REDUCE == SUM) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (scalar_t)max(count, 1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }
};
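
// Semantics in short: for a row with values {1, 2, 3}, SUM writes 6, MEAN
// divides by the count and writes 2, and MIN/MAX additionally record the
// winning argument index via `arg`. Empty rows (count == 0) write 0 for
// MIN/MAX and leave the arg output untouched.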

// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code:  https://github.com/owensgroup/merge-spmm
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
                            const scalar_t *value_data,
                            const scalar_t *mat_data, scalar_t *out_data,
                            int64_t *arg_out_data, int B, int M, int N, int K) {

  // `thread_idx` deliberately ignores `blockIdx.y`: every block along the
  // y-axis covers the same rows and only selects a different 32-column tile
  // of `mat`.
  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  int row = thread_idx >> 5;            // thread_idx / 32
  int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
  int batch_idx = row / M;
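  // One warp handles one output row (per batch); within the warp, each lane
  // covers one of 32 consecutive columns of `mat`.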

  // Compute the column index of `mat` in which the thread is operating.
  int mat_col_idx = lane_idx + (blockIdx.y << 5);

  // Compute the output index (row-major order).
  int out_idx = row * K + mat_col_idx;

  // Helper arrays for warp communication.
  int mat_row, mat_rows[32];
  scalar_t val, vals[HAS_VAL ? 32 : 1];

  // Do not aggregate/write across the Y-axis (lane_idx < leftover).
  int leftover = K - (blockIdx.y << 5);
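  // `leftover` is the number of valid columns in this 32-wide tile of `mat`;
  // lanes with `lane_idx >= leftover` would fall out of bounds when K is not
  // a multiple of 32.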

  if (batch_idx < B) {
    int row_start = __ldg(rowptr_data + (row % M));
    int row_end = __ldg(rowptr_data + (row % M) + 1);
    int col_idx = row_start + lane_idx;
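    // Each lane starts at its own nonzero and advances by the warp size, so
    // the warp sweeps the entire row together.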

    scalar_t result = Reducer<scalar_t, REDUCE>::init();
    int64_t arg;

    // Iterate over all `col` indices in parallel within a warp.
    for (int c = row_start; c < row_end; c += 32) {

      if (col_idx < row_end) {
        // Coalesced memory access into `col` and `val`.
        mat_row = __ldg(col_data + col_idx) * K;
        if (HAS_VAL)
          val = __ldg(value_data + col_idx);
      } else {
        mat_row = -1;
        if (HAS_VAL)
          val = (scalar_t)0;
      }
      col_idx += 32;

#pragma unroll
      for (int i = 0; i < 32; i++) {
        // Communication between all threads in a warp.
        mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
        if (HAS_VAL)
          vals[i] = __shfl_sync(FULL_MASK, val, i);
      }

#pragma unroll
      for (int i = 0; i < 32; i++) {
        if (lane_idx < leftover && mat_rows[i] != -1) {
          // Coalesced memory access into `mat`.
          val = __ldg(mat_data + batch_idx * N * K + mat_rows[i] + mat_col_idx);
          if (HAS_VAL)
            val = vals[i] * val;
          Reducer<scalar_t, REDUCE>::update(&result, val, &arg, c + i);
        }
      }
    }

    if (lane_idx < leftover) {
      // Coalesced write into `out`.
      Reducer<scalar_t, REDUCE>::write(out_data + out_idx, result,
                                       arg_out_data + out_idx, arg,
                                       row_end - row_start);
    }
  }
}

std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
          at::Tensor mat, std::string reduce) {

  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
  AT_ASSERTM(col.dim() == 1, "Input mismatch");
  if (value_opt.has_value())
    AT_ASSERTM(value_opt.value().dim() == 1, "Input mismatch");
  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");

  mat = mat.contiguous();

  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
  auto out = at::empty(sizes, mat.options());
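  // Same shape as `mat`, except that dim -2 now has one entry per sparse row.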

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = at::full_like(out, col.numel(), rowptr.options());
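    // `col.numel()` serves as a sentinel: rows without nonzeros keep it, since
    // the kernel only writes `arg_out` for rows with count > 0.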
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto M = rowptr.numel() - 1;
  auto N = mat.size(-2);
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);
  auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);
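  // Grid: x launches one warp (32 threads) per output row over all batches,
  // packed into blocks of THREADS; y tiles the K columns of `mat` in chunks
  // of 32.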

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
    auto rowptr_data = rowptr.DATA_PTR<int64_t>();
    auto col_data = col.DATA_PTR<int64_t>();
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (value_opt.has_value()) {
        auto value_data = value_opt.value().DATA_PTR<scalar_t>();
        spmm_kernel<scalar_t, REDUCE, true><<<BLOCKS, THREADS, 0, stream>>>(
            rowptr_data, col_data, value_data, mat_data, out_data, arg_out_data,
            B, M, N, K);
      } else {
        spmm_kernel<scalar_t, REDUCE, false><<<BLOCKS, THREADS, 0, stream>>>(
            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, B,
            M, N, K);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}
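
// Usage sketch (assuming CSR inputs `rowptr`, `col`, `value` and a dense
// `mat`, all on the same CUDA device):
//
//   at::Tensor out;
//   at::optional<at::Tensor> arg_out;
//   std::tie(out, arg_out) = spmm_cuda(rowptr, col, value, mat, "max");
//
// `arg_out` only holds a tensor for "min"/"max"; for "sum", "add" and "mean"
// it is at::nullopt.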

template <typename scalar_t, ReductionType REDUCE>
__global__ void
spmm_val_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
                   const int64_t *col_data, const scalar_t *mat_data,
                   const scalar_t *grad_data, scalar_t *out_data, int B, int M,
                   int N, int E, int K) {
  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  int index_idx = (thread_idx >> 5);    // thread_idx / 32
  int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
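  // One warp per nonzero entry: the lanes stride over the K feature columns
  // (and all batches), and their partial sums are combined by the warp
  // reduction below.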

  if (index_idx < E) {
    int row = __ldg(row_data + index_idx);
    int col = __ldg(col_data + index_idx);

    scalar_t val = (scalar_t)0;
    for (int b = 0; b < B; b++) {
      for (int k = lane_idx; k < K; k += 32) {
        val += mat_data[b * N * K + col * K + k] *
               grad_data[b * M * K + row * K + k];
      }
    }

#pragma unroll
    for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
      val += __shfl_down_sync(FULL_MASK, val, i);
    }

    if (lane_idx == 0) {
      if (REDUCE == MEAN) {
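        // The forward pass divides the aggregation by the row's nonzero count
        // for "mean", so the same factor appears in the gradient.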
        int row_start = __ldg(rowptr_data + row);
        int row_end = __ldg(rowptr_data + row + 1);
        val /= (scalar_t)max(row_end - row_start, 1);
      }
      out_data[index_idx] = val;
    }
  }
}

at::Tensor spmm_val_bw_cuda(at::Tensor row, at::Tensor rowptr, at::Tensor col,
                            at::Tensor mat, at::Tensor grad,
                            std::string reduce) {

  mat = mat.contiguous();
  grad = grad.contiguous();

  auto M = grad.size(-2);
  auto N = mat.size(-2);
  auto E = row.numel();
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);
  auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
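  // One warp (32 threads) per nonzero entry, packed into blocks of THREADS.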

  auto out = at::zeros(row.numel(), grad.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
    auto row_data = row.DATA_PTR<int64_t>();
    auto rowptr_data = rowptr.DATA_PTR<int64_t>();
    auto col_data = col.DATA_PTR<int64_t>();
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto grad_data = grad.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      spmm_val_bw_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS, 0, stream>>>(
          row_data, rowptr_data, col_data, mat_data, grad_data, out_data, B, M,
          N, E, K);
    });
  });

  return out;
}
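
// Usage sketch (assuming `grad` is the incoming gradient w.r.t. the SpMM
// output and `row` holds the COO row index of every nonzero):
//
//   auto grad_value = spmm_val_bw_cuda(row, rowptr, col, mat, grad, "mean");
//
// The result has one entry per nonzero, i.e. the gradient w.r.t. `value`.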