spmm_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <limits>
#include <map>

#include "compat.cuh"

#define THREADS 256
#define FULL_MASK 0xffffffff

enum ReductionType { SUM, MEAN, MIN, MAX };

const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};

#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    switch (reduce2REDUCE.at(reduce)) {                                        \
    case SUM: {                                                                \
      const ReductionType REDUCE = SUM;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MEAN: {                                                               \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MIN: {                                                                \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MAX: {                                                                \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()
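
// Hedged usage sketch: the macro turns the runtime `reduce` string into the
// compile-time constant `REDUCE` inside the lambda, e.g.
//
//   AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
//     some_kernel<scalar_t, REDUCE><<<blocks, threads>>>(...);
//   });
//
// (`some_kernel` is a placeholder; `spmm_cuda` below uses the macro this way.)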

template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline __host__ __device__ scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                int64_t *arg, int64_t new_arg) {
    if (REDUCE == SUM || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
                                               int64_t *arg_address,
                                               int64_t arg, int count) {
    if (REDUCE == SUM) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (scalar_t)max(count, 1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }
};
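
// Worked example (illustrative): reducing the row values {3, 1, 2} with
// REDUCE == MIN leaves val == 1 and arg at the edge index that supplied it;
// with REDUCE == MEAN, update() accumulates the sum 6 and write() divides it
// by count == 3. Rows with count == 0 are written as 0.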

// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code:  https://github.com/owensgroup/merge-spmm
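//
// Algorithm sketch: each warp owns one sparse row. Per 32-element chunk, the
// warp loads up to 32 column indices (and values) of its row with coalesced
// reads, broadcasts them lane by lane via `__shfl_sync`, and each lane
// accumulates one dense output column, so reads from `mat` stay coalesced too.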
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
                            const scalar_t *value_data,
                            const scalar_t *mat_data, scalar_t *out_data,
                            int64_t *arg_out_data, int B, int M, int N, int K) {

  // Ignore `blockIdx.y` here: it only tiles the dense column dimension,
  // so it does not affect which sparse row a thread handles.
  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  int row = thread_idx >> 5;            // thread_idx / 32
  int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
  int batch_idx = row / M;

  // Compute the column index of `mat` on which this thread operates.
  int mat_col_idx = lane_idx + (blockIdx.y << 5);

  // Compute the output index (row-major order).
  int out_idx = row * K + mat_col_idx;

  // Helper arrays for warp communication.
  int mat_row, mat_rows[32];
  scalar_t val, vals[HAS_VAL ? 32 : 1];

  // Number of valid dense columns in this 32-wide tile; lanes with
  // `lane_idx >= leftover` must not aggregate or write.
  int leftover = K - (blockIdx.y << 5);

  if (row < B * M) {
    int row_start = __ldg(rowptr_data + (row % M));
    int row_end = __ldg(rowptr_data + (row % M) + 1);
    int col_idx = row_start + lane_idx;

    scalar_t result = Reducer<scalar_t, REDUCE>::init();
    int64_t arg;

    // Iterate over all `col` indices in parallel within a warp.
    for (int c = row_start; c < row_end; c += 32) {

      if (col_idx < row_end) {
        // Coalesced memory access into `col` and `val`.
        mat_row = __ldg(col_data + col_idx) * K;
        if (HAS_VAL)
          val = __ldg(value_data + col_idx);
      } else {
        mat_row = -1;
        if (HAS_VAL)
          val = (scalar_t)0;
      }
      col_idx += 32;

#pragma unroll
      for (int i = 0; i < 32; i++) {
        // Communication between all threads in a warp.
        mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
        if (HAS_VAL)
          vals[i] = __shfl_sync(FULL_MASK, val, i);
      }

#pragma unroll
      for (int i = 0; i < 32; i++) {
        if (lane_idx < leftover && mat_rows[i] != -1) {
          // Coalesced memory access into `mat`.
          val = __ldg(mat_data + batch_idx * N * K + mat_rows[i] + mat_col_idx);
          if (HAS_VAL)
            val = vals[i] * val;
          Reducer<scalar_t, REDUCE>::update(&result, val, &arg, c + i);
        }
      }
    }

    if (lane_idx < leftover) {
      // Coalesced write into `out`.
      Reducer<scalar_t, REDUCE>::write(out_data + out_idx, result,
                                       arg_out_data + out_idx, arg,
                                       row_end - row_start);
    }
  }
}

std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
          at::Tensor mat, std::string reduce) {

  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
  AT_ASSERTM(col.dim() == 1, "Input mismatch");
  if (value_opt.has_value())
    AT_ASSERTM(value_opt.value().dim() == 1, "Input mismatch");
  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");

  mat = mat.contiguous();

  auto sizes = mat.sizes().vec();
  sizes[mat.dim() - 2] = rowptr.numel() - 1;
  auto out = at::empty(sizes, mat.options());

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = at::full_like(out, col.numel(), rowptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();

  auto M = rowptr.numel() - 1;
  auto N = mat.size(-2);
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);
  auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);
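  // One 32-thread warp per (batch, row) pair along x; blockIdx.y tiles the
  // K dense columns in chunks of 32 (see `mat_col_idx` in the kernel).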

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (value_opt.has_value()) {
        auto value_data = value_opt.value().DATA_PTR<scalar_t>();
        spmm_kernel<scalar_t, REDUCE, true><<<BLOCKS, THREADS, 0, stream>>>(
            rowptr_data, col_data, value_data, mat_data, out_data, arg_out_data,
            B, M, N, K);
      } else {
        spmm_kernel<scalar_t, REDUCE, false><<<BLOCKS, THREADS, 0, stream>>>(
            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, B,
            M, N, K);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}
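
// Minimal usage sketch (hedged; shapes are illustrative, assuming this file
// is built into a Torch extension that exposes `spmm_cuda`):
//
//   auto rowptr = at::tensor({0, 2, 3}, at::kLong).cuda(); // M = 2 rows
//   auto col    = at::tensor({0, 2, 1}, at::kLong).cuda(); // 3 nonzeros
//   auto value  = at::tensor({1.0f, 2.0f, 3.0f}).cuda();
//   auto mat    = at::rand({3, 8}).cuda();                 // N = 3, K = 8
//   at::Tensor out;
//   at::optional<at::Tensor> arg_out;
//   std::tie(out, arg_out) = spmm_cuda(rowptr, col, value, mat, "sum");
//   // out has shape [2, 8]; arg_out is set only for "min"/"max".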