spmm_kernel.cu 3.27 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include <ATen/ATen.h>
rusty1s's avatar
rusty1s committed
2
#include <ATen/cuda/CUDAContext.h>
rusty1s's avatar
rusty1s committed
3
4
5

#include "compat.cuh"

rusty1s's avatar
rusty1s committed
6
7
#define Y_SIZE 32
#define THREADS 256

// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code:  https://github.com/owensgroup/merge-spmm

// Row-parallel SpMM kernel: computes out = A @ mat for a CSR matrix A given by
// (rowptr_data, col_data, val_data), with `mat` and `out` dense and row-major.
//
// Expected launch layout (see the host wrapper in this file):
//   - blockIdx.x / threadIdx.x: one full warp (32 lanes) per sparse row;
//   - blockIdx.y: one 32-column-wide slice of `mat`/`out` per y-block.
// No shared memory is used; all warp communication goes through __shfl_sync.
// Requires the *_sync shuffle intrinsics (SM30+; mandatory form on Volta+).
template <typename scalar_t>
__global__ void
spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
                const scalar_t *val_data, const scalar_t *mat_data,
                scalar_t *out_data, size_t N, size_t K) {

  // We ignore blockIdx.y here, because threads across blockIdx.y operate on the
  // same row.
  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  int warp_idx = thread_idx >> 5;       // thread_id / 32
  int lane_idx = thread_idx & (32 - 1); // thread_id % 32
  int row = warp_idx;                   // Each warp processes exactly one row.

  // Compute the column index of `mat` in which the thread is operating.
  int mat_col_idx = lane_idx + (blockIdx.y << 5);

  // Compute the output index (row-major order).
  // Equal to row * K + mat_col_idx; written only when mat_col_idx < K below.
  int out_idx = row * K + lane_idx + (blockIdx.y << 5);

  // Helper arrays for warp communication.
  // Indexed only by the fully unrolled constant `j` below, so the compiler can
  // keep them in registers (presumably — compiler-dependent; verify with ptxas).
  int mat_row_all[Y_SIZE];
  scalar_t val_all[Y_SIZE];

  // Number of valid `mat` columns left in this y-block's 32-wide slice.
  // Lanes with lane_idx >= leftover fall outside `mat`/`out` and are masked
  // out of all global loads/stores of this slice (they still shuffle).
  int leftover = K - (blockIdx.y << 5);

  // `row < N` is uniform across the warp (all 32 lanes share warp_idx), so a
  // participating warp enters with all lanes active — this is what makes the
  // full 0xffffffff shuffle masks below valid.
  if (row < N) {
    // CSR extent of this row's nonzeros: [row_start, row_end).
    // NOTE(review): narrowed from int64_t to int — assumes nnz < 2^31; confirm.
    int row_start = __ldg(rowptr_data + row);
    int row_end = __ldg(rowptr_data + row + 1);
    // Each lane starts at its own nonzero; advanced by 32 per outer iteration.
    int col_idx = row_start + lane_idx;

    int mat_row = -1;
    scalar_t val = (scalar_t)0;
    scalar_t sum = (scalar_t)0;

    // Iterate over all col indices in parallel with 32 threads.
    for (int c = row_start; c < row_end; c += 32) {

      if (col_idx < row_end) {
        // Coalesced memory access into `col` and `val`.
        // mat_row is pre-scaled by K: it is the byte-free row offset into the
        // row-major `mat`.
        mat_row = __ldg(col_data + col_idx) * K;
        val = __ldg(val_data + col_idx);
      } else {
        // Tail lanes past row_end contribute a harmless identity: row 0 of
        // `mat` scaled by 0, so the shuffle exchange below always reads
        // defined values from every lane.
        mat_row = 0;
        val = (scalar_t)0;
      }
      col_idx += 32;

// Broadcast phase: every lane learns the (mat_row, val) pair held by each of
// the 32 lanes, then the accumulate phase consumes all 32 pairs. With
// Y_SIZE == 32 the outer loop runs exactly once.
#pragma unroll
      for (int i = 0; i < 32; i += Y_SIZE) {
#pragma unroll
        for (int j = 0; j < Y_SIZE; j++) {
          // Communication between *all* threads in a warp.
          mat_row_all[j] = __shfl_sync(0xffffffff, mat_row, i + j);
          val_all[j] = __shfl_sync(0xffffffff, val, i + j);
        }
#pragma unroll
        for (int j = 0; j < Y_SIZE; j++) {
          // Guard keeps lanes past the last valid column (K not a multiple of
          // 32) from reading out of bounds; it does not diverge the shuffles
          // above, which have already completed.
          if (lane_idx < leftover) {
            // Coalesced memory access into `mat`.
            sum += val_all[j] * __ldg(mat_data + mat_row_all[j] + mat_col_idx);
          }
        }
      }
    }
    // Rows with no nonzeros skip the loop and store an explicit 0 here, so
    // `out` never retains uninitialized at::empty contents for valid columns.
    if (lane_idx < leftover) {
      // Coalesced memory access into `out`.
      out_data[out_idx] = sum;
    }
  }
}

// Host entry point: out = CSR(rowptr, col, val) @ mat on the current CUDA
// stream. `rowptr` has N + 1 entries, `mat` is (M, K) row-major; the result
// is a freshly allocated (N, K) tensor with `mat`'s dtype/device options.
at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
                     at::Tensor mat) {
  auto N = rowptr.numel() - 1; // number of output rows
  auto K = mat.size(1);        // number of dense columns
  auto out = at::empty({N, K}, mat.options());

  // Index arrays are dtype-independent; grab them outside the dispatch.
  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();

  // One warp (32 threads) per sparse row along x, and one 32-column slice of
  // `mat`/`out` per block along y — both rounded up to cover ragged edges.
  const dim3 block_dim(THREADS);
  const dim3 grid_dim((32 * N + THREADS - 1) / THREADS, (K + 31) / 32);
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
    spmm_row_kernel<scalar_t><<<grid_dim, block_dim, 0, stream>>>(
        rowptr_data, col_data, val.DATA_PTR<scalar_t>(),
        mat.DATA_PTR<scalar_t>(), out.DATA_PTR<scalar_t>(), N, K);
  });

  return out;
}