spmm_cuda.cu 7.55 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
#include "spmm_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "reducer.cuh"
#include "utils.cuh"

#define THREADS 256
#define FULL_MASK 0xffffffff

// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
limm's avatar
limm committed
12

limm's avatar
limm committed
13
14
15
16
17
18
19
20
21
22
// Sparse (CSR) x dense matrix multiplication with a configurable reduction.
//
// Inputs:
//   rowptr_data / col_data  CSR row pointers and column indices of the sparse
//                           matrix; `value_data` holds the matching non-zero
//                           values and is only read when HAS_VALUE is true.
//   mat_data                dense row-major input, indexed as [B, N, K].
//   out_data                dense row-major output, indexed as [B, M, K].
//   arg_out_data            forwarded to `Reducer::write` together with the
//                           tracked argument index (may be null on the host
//                           side for reductions that do not need it).
//
// Launch layout: every group of 64 consecutive threads along x cooperates on
// one of the B * M output rows, and `blockIdx.y` selects a tile of 64 output
// columns. All lanes of a group share the same `row`, so the `batch_idx < B`
// guard is uniform across the lanes that exchange data via SHFL_SYNC.
// NOTE(review): this presumably requires a 64-wide warp/wavefront and a
// SHFL_SYNC implementation that honours 64 lanes — confirm in utils.cuh.
//
// Sparse indices (`row_start`, `row_end`, column ids) are kept as 32-bit
// ints; all offsets into `mat` and `out` are computed in 64 bits so that
// large dense tensors do not overflow the index arithmetic.
template <typename scalar_t, ReductionType REDUCE, bool HAS_VALUE>
__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
                            const scalar_t *value_data,
                            const scalar_t *mat_data, scalar_t *out_data,
                            int64_t *arg_out_data, int B, int M, int N, int K) {
  // Number of cooperating lanes per output row and its log2.
  constexpr int LANES = 64;
  constexpr int LANES_LOG2 = 6;

  // We ignore blockIdx.y here, because threads
  // across `blockIdx.y` are treated equally.
  const int64_t thread_idx =
      (int64_t)blockDim.x * blockIdx.x + threadIdx.x;

  const int64_t row = thread_idx >> LANES_LOG2;               // thread_idx / 64
  const int lane_idx = (int)(thread_idx & (LANES - 1));       // thread_idx % 64
  const int64_t batch_idx = row / M;

  // Compute the column index of `mat` in which the thread is operating.
  const int mat_col_idx = lane_idx + (int)(blockIdx.y * LANES);

  // Compute the output index (row-major order) in 64 bits.
  const int64_t out_idx = row * K + mat_col_idx;

  // Helper arrays for communication between the lanes of a group.
  // `mat_rows` holds the column index of each sparse entry, i.e. the row of
  // `mat` it selects; -1 marks "no entry".
  int mat_row, mat_rows[LANES];
  scalar_t val, vals[HAS_VALUE ? LANES : 1];

  // Do not aggregate/write across the Y-axis (lane_idx < leftover).
  const int leftover = K - (int)(blockIdx.y * LANES);

  if (batch_idx < B) {
    const int row_start = (int)__ldg(rowptr_data + (row % M));
    const int row_end = (int)__ldg(rowptr_data + (row % M) + 1);
    int col_idx = row_start + lane_idx;

    // Offset of this batch inside `mat`.
    const int64_t mat_batch_offset = batch_idx * N * K;

    scalar_t result = Reducer<scalar_t, REDUCE>::init();
    // Initialised so that rows without entries never pass an indeterminate
    // value to `Reducer::write`.
    int64_t arg = 0;

    // Iterate over all `col` indices in parallel within a group of lanes.
    for (int c = row_start; c < row_end; c += LANES) {

      if (col_idx < row_end) {
        // Coalesced memory access into `col` and `val`.
        mat_row = (int)__ldg(col_data + col_idx);
        if (HAS_VALUE)
          val = __ldg(value_data + col_idx);
      } else {
        mat_row = -1;
        if (HAS_VALUE)
          val = (scalar_t)0;
      }
      col_idx += LANES;

#pragma unroll
      for (int i = 0; i < LANES; i++) {
        // Communication between all lanes of the group.
        mat_rows[i] = SHFL_SYNC(FULL_MASK, mat_row, i);
        if (HAS_VALUE)
          vals[i] = SHFL_SYNC(FULL_MASK, val, i);
      }

#pragma unroll
      for (int i = 0; i < LANES; i++) {
        if (lane_idx < leftover && mat_rows[i] != -1) {
          // Coalesced memory access into `mat`; the row offset is widened to
          // 64 bits before multiplying by K.
          val = __ldg(mat_data + mat_batch_offset +
                      (int64_t)mat_rows[i] * K + mat_col_idx);
          if (HAS_VALUE)
            val = vals[i] * val;
          Reducer<scalar_t, REDUCE>::update(&result, val, &arg, c + i);
        }
      }
    }

    if (lane_idx < leftover) {
      // Coalesced write into `out`.
      Reducer<scalar_t, REDUCE>::write(out_data + out_idx, result,
                                       arg_out_data + out_idx, arg,
                                       row_end - row_start);
    }
  }
}

// Host entry point for the sparse (CSR) x dense product `out = A @ mat`.
//
// Arguments:
//   rowptr, col     1-D int64 CUDA tensors holding the CSR structure of A.
//   optional_value  optional 1-D CUDA tensor with one value per entry of
//                   `col`; when absent the kernel runs without values.
//   mat             CUDA tensor with at least two dimensions; the last two
//                   are treated as [N, K], everything before as batch.
//   reduce          name of the reduction, resolved through `reduce2REDUCE`.
//
// Returns `(out, arg_out)`: `out` has the shape of `mat` with dimension -2
// replaced by `rowptr.numel() - 1`; `arg_out` is only allocated for MIN/MAX
// reductions (pre-filled with `col.numel()`), otherwise it is `nullopt`.
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
          torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
          std::string reduce) {

  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  if (optional_value.has_value()) {
    CHECK_CUDA(optional_value.value());
  }
  CHECK_CUDA(mat);
  AT_CUDA_CHECK(cudaSetDevice(rowptr.get_device()));

  CHECK_INPUT(rowptr.dim() == 1);
  // A CSR row pointer always has at least one entry; without this check the
  // output size below would become negative.
  CHECK_INPUT(rowptr.numel() > 0);
  CHECK_INPUT(col.dim() == 1);
  if (optional_value.has_value()) {
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().size(0) == col.size(0));
  }
  CHECK_INPUT(mat.dim() >= 2);

  mat = mat.contiguous();

  auto M = rowptr.numel() - 1;
  auto N = mat.size(-2);
  auto K = mat.size(-1);

  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = M;
  auto out = torch::empty(sizes, mat.options());

  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  const auto reduce_type = reduce2REDUCE.at(reduce);
  if (reduce_type == MIN || reduce_type == MAX) {
    arg_out = torch::full_like(out, col.numel(), rowptr.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

  // Nothing to compute when the output is empty (M == 0, K == 0 or an empty
  // batch). Launching would request a grid with a zero dimension, which is an
  // invalid configuration.
  if (out.numel() == 0)
    return std::make_tuple(out, arg_out);

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();

  // `out` shares its batch dimensions with `mat`; deriving B from `out`
  // avoids dividing by N * K, which is zero when `mat` has no rows.
  auto B = out.numel() / (M * K);
  // 64 threads cooperate on each of the B * M output rows; the y-dimension
  // tiles the K output columns in chunks of 64.
  auto BLOCKS = dim3((64 * B * M + THREADS - 1) / THREADS, (K + 63) / 64);

  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
    auto mat_data = mat.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (optional_value.has_value()) {
        auto value_data = optional_value.value().data_ptr<scalar_t>();
        spmm_kernel<scalar_t, REDUCE, true><<<BLOCKS, THREADS, 0, stream>>>(
            rowptr_data, col_data, value_data, mat_data, out_data, arg_out_data,
            B, M, N, K);
      } else {
        spmm_kernel<scalar_t, REDUCE, false><<<BLOCKS, THREADS, 0, stream>>>(
            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, B,
            M, N, K);
      }
    });
  });

  // Kernel launches do not report configuration errors on their own.
  AT_CUDA_CHECK(cudaGetLastError());

  return std::make_tuple(out, arg_out);
}

limm's avatar
limm committed
157

limm's avatar
limm committed
158
159
160
161
162
163
164
165
// Backward pass of the sparse-dense product with respect to the sparse
// values: for every sparse entry e = (row, col) it accumulates
//   out[e] = sum_b sum_k mat[b, col, k] * grad[b, row, k]
// and, for the MEAN reduction, divides by the number of entries in `row`
// (at least 1).
//
// Launch layout: 1-D grid; every group of 64 consecutive threads handles one
// of the E sparse entries. Each lane accumulates a strided slice of the K
// columns, the partial sums are combined with SHFL_DOWN_SYNC, and lane 0
// writes the result. `index_idx` is identical for all lanes of a group, so
// the `index_idx < E` guard is uniform across the lanes that shuffle.
// NOTE(review): this presumably requires a 64-wide warp/wavefront and a
// SHFL_DOWN_SYNC implementation that honours 64 lanes — confirm in utils.cuh.
//
// `mat` is indexed as [B, N, K] and `grad` as [B, M, K] (row-major); offsets
// into both are computed in 64 bits to avoid 32-bit overflow.
template <typename scalar_t, ReductionType REDUCE>
__global__ void
spmm_value_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
                     const int64_t *col_data, const scalar_t *mat_data,
                     const scalar_t *grad_data, scalar_t *out_data, int B,
                     int M, int N, int E, int K) {
  // Number of cooperating lanes per sparse entry and its log2.
  constexpr int LANES = 64;
  constexpr int LANES_LOG2 = 6;

  const int64_t thread_idx =
      (int64_t)blockDim.x * blockIdx.x + threadIdx.x;

  const int64_t index_idx = thread_idx >> LANES_LOG2;         // thread_idx / 64
  const int lane_idx = (int)(thread_idx & (LANES - 1));       // thread_idx % 64

  if (index_idx < E) {
    const int row = (int)__ldg(row_data + index_idx);
    const int col = (int)__ldg(col_data + index_idx);

    // Row offsets inside one batch slice; invariant across the batch loop.
    const int64_t mat_row_offset = (int64_t)col * K;
    const int64_t grad_row_offset = (int64_t)row * K;

    scalar_t val = (scalar_t)0;
    for (int b = 0; b < B; b++) {
      const scalar_t *mat_row = mat_data + (int64_t)b * N * K + mat_row_offset;
      const scalar_t *grad_row =
          grad_data + (int64_t)b * M * K + grad_row_offset;
      for (int k = lane_idx; k < K; k += LANES) {
        val += mat_row[k] * grad_row[k];
      }
    }

#pragma unroll
    for (int i = LANES / 2; i > 0; i /= 2) { // Parallel reduction over lanes.
      val += SHFL_DOWN_SYNC(FULL_MASK, val, i);
    }

    if (lane_idx == 0) {
      if (REDUCE == MEAN) {
        const int row_start = (int)__ldg(rowptr_data + row);
        const int row_end = (int)__ldg(rowptr_data + row + 1);
        // Guard against dividing by zero for rows without entries.
        val /= (scalar_t)max(row_end - row_start, 1);
      }
      out_data[index_idx] = val;
    }
  }
}

// Host entry point for the gradient of the sparse-dense product with respect
// to the sparse values.
//
// Arguments:
//   row, rowptr, col  int64 CUDA tensors describing the sparse matrix (COO
//                     row indices, CSR row pointers and column indices).
//   mat               dense CUDA tensor; its last two dimensions are [N, K].
//   grad              dense CUDA tensor; its last two dimensions are [M, K].
//   reduce            name of the reduction used in the forward pass.
//
// Returns a 1-D tensor with `row.numel()` elements using `grad`'s options.
torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
                                 torch::Tensor col, torch::Tensor mat,
                                 torch::Tensor grad, std::string reduce) {
  CHECK_CUDA(row);
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(mat);
  CHECK_CUDA(grad);
  AT_CUDA_CHECK(cudaSetDevice(row.get_device()));

  mat = mat.contiguous();
  grad = grad.contiguous();

  auto E = row.numel();
  auto out = torch::zeros({E}, grad.options());

  // Without sparse entries the grid would have zero blocks (an invalid launch
  // configuration); with an empty `mat` the batch size below would divide by
  // zero. In both cases every accumulated sum is empty, so the zero-filled
  // `out` is already the result.
  if (E == 0 || mat.numel() == 0)
    return out;

  auto M = grad.size(-2);
  auto N = mat.size(-2);
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);
  // 64 threads cooperate on each sparse entry.
  auto BLOCKS = dim3((E * 64 + THREADS - 1) / THREADS);

  auto row_data = row.data_ptr<int64_t>();
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
    auto mat_data = mat.data_ptr<scalar_t>();
    auto grad_data = grad.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      spmm_value_bw_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS, 0, stream>>>(
          row_data, rowptr_data, col_data, mat_data, grad_data, out_data, B, M,
          N, E, K);
    });
  });

  // Kernel launches do not report configuration errors on their own.
  AT_CUDA_CHECK(cudaGetLastError());

  return out;
}