"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "10dc06c8d982f745e54d8d9daff6b258726b8172"
Commit 519306d3 authored by rusty1s

fast as fuck spmm kernel

parent 36d045fd
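For reference, the kernel touched by this commit computes a sparse-dense product out = A · mat, where A is an N × M CSR matrix described by rowptr, col, and val, and mat is a dense M × K matrix stored row-major. A minimal single-threaded sketch of that computation (not part of this commit; the function and variable names are illustrative only):

// Reference semantics only: out[n][k] accumulates val[e] * mat[col[e]][k]
// over all nonzeros e of row n. The CUDA kernel below parallelizes this by
// assigning one 32-thread warp per row and tiling the K columns of `mat`
// across blockIdx.y.
#include <cstdint>
#include <vector>

std::vector<float> spmm_reference(const std::vector<int64_t> &rowptr,
                                  const std::vector<int64_t> &col,
                                  const std::vector<float> &val,
                                  const std::vector<float> &mat, // M x K
                                  int64_t N, int64_t K) {
  std::vector<float> out(N * K, 0.0f);
  for (int64_t n = 0; n < N; n++) {
    for (int64_t e = rowptr[n]; e < rowptr[n + 1]; e++) {
      for (int64_t k = 0; k < K; k++) {
        out[n * K + k] += val[e] * mat[col[e] * K + k];
      }
    }
  }
  return out;
}

The commit itself replaces the float-only, Y_SIZE-templated launch with a per-warp loop over each row's nonzeros in the kernel and an AT_DISPATCH_ALL_TYPES launch on the current CUDA stream, as the diff below shows.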
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>

 #include "compat.cuh"

-#define THREADS 32 * 16
+#define Y_SIZE 32
+#define THREADS 256

 // Paper: Design Principles for Sparse Matrix Multiplication on the GPU
 // Code: https://github.com/owensgroup/merge-spmm
-template <typename scalar_t, size_t Y_SIZE>
+template <typename scalar_t>
 __global__ void
 spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
                 const scalar_t *val_data, const scalar_t *mat_data,
-                scalar_t *out_data, size_t N, size_t M, size_t K) {
+                scalar_t *out_data, size_t N, size_t K) {

   // We ignore blockIdx.y here, because threads across blockIdx.y operate on the
   // same row.
@@ -23,7 +25,7 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
   // Compute the column index of `mat` in which the thread is operating.
   int mat_col_idx = lane_idx + (blockIdx.y << 5);

-  // Compute the output index given in row-major order.
+  // Compute the output index (row-major order).
   int out_idx = row * K + lane_idx + (blockIdx.y << 5);

   // Helper arrays for warp communication.
@@ -35,18 +37,30 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
   if (row < N) {
     int row_start = __ldg(rowptr_data + row);
     int row_end = __ldg(rowptr_data + row + 1);
+    int col_idx = row_start + lane_idx;

-    // Iterate over all col indices in parallel.
-    for (int col_idx = row_start + lane_idx; col_idx < row_end; col_idx += 32) {
-      int mat_row = __ldg(col_data + col_idx) * K;
-      int val = __ldg(val_data + col_idx);
-      scalar_t sum = (scalar_t)0;
+    int mat_row = -1;
+    scalar_t val = (scalar_t)0;
+    scalar_t sum = (scalar_t)0;
+
+    // Iterate over all col indices in parallel with 32 threads.
+    for (int c = row_start; c < row_end; c += 32) {
+      if (col_idx < row_end) {
+        // Coalesced memory access into `col` and `val`.
+        mat_row = __ldg(col_data + col_idx) * K;
+        val = __ldg(val_data + col_idx);
+      } else {
+        mat_row = 0;
+        val = (scalar_t)0;
+      }
+      col_idx += 32;

+#pragma unroll
       for (int i = 0; i < 32; i += Y_SIZE) {
 #pragma unroll
         for (int j = 0; j < Y_SIZE; j++) {
-          // Warp communication with *all* threads (mask = 0xffffffff).
+          // Communication between *all* threads in a warp.
+          // TODO: Compute real bit mask via `__ballot_sync()`.
           mat_row_all[j] = __shfl_sync(0xffffffff, mat_row, i + j);
           val_all[j] = __shfl_sync(0xffffffff, val, i + j);
         }
@@ -58,34 +72,35 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
         }
       }
     }
+    }

-    if (lane_idx < leftover) {
-      out_data[out_idx] = sum;
-    }
+    if (lane_idx < leftover) {
+      // Coalesced memory access into `out`.
+      out_data[out_idx] = sum;
+    }
   }
 }

 at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
                      at::Tensor mat) {
+  // TODO: Set device
   auto N = rowptr.numel() - 1;
-  auto M = mat.size(0);
   auto K = mat.size(1);
   auto out = at::empty({N, K}, mat.options());

   auto rowptr_data = rowptr.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto val_data = val.DATA_PTR<float>();
-  auto mat_data = mat.DATA_PTR<float>();
-  auto out_data = out.DATA_PTR<float>();

   auto block_dim = dim3(THREADS);
-  auto grid_dim = dim3((N + THREADS - 1) / THREADS, (K + 32 - 1) / 32);
+  auto grid_dim = dim3((32 * N + THREADS - 1) / THREADS, (K + 31) / 32);

-  spmm_row_kernel<float, 32><<<grid_dim, block_dim, 0 /*, cuda_stream */>>>(
-      rowptr_data, col_data, val_data, mat_data, out_data, N, M, K);
+  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
+    auto val_data = val.DATA_PTR<scalar_t>();
+    auto mat_data = mat.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    spmm_row_kernel<scalar_t>
+        <<<grid_dim, block_dim, 0, at::cuda::getCurrentCUDAStream()>>>(
+            rowptr_data, col_data, val_data, mat_data, out_data, N, K);
+  });

   return out;
 }
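A usage sketch of the updated entry point follows; the forward declaration, the run_example driver, and the identity-matrix inputs are assumptions for illustration and are not part of the commit:

// Hypothetical caller, assuming spmm_cuda is linked into the same extension.
#include <ATen/ATen.h>

at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
                     at::Tensor mat);

at::Tensor run_example() {
  int64_t N = 1024, K = 64;
  auto idx_opts = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);
  // CSR identity matrix: row n holds a single nonzero 1.0 at column n.
  auto rowptr = at::arange(0, N + 1, idx_opts);
  auto col = at::arange(0, N, idx_opts);
  auto val = at::ones({N}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA));
  auto mat = at::rand({N, K}, val.options());
  // With THREADS = 256 (eight warps per block), the new grid uses
  // ceil(32 * N / 256) = N / 8 blocks in x (one warp per sparse row) and
  // ceil(K / 32) = 2 blocks in y (one 32-column tile of `mat` each).
  return spmm_cuda(rowptr, col, val, mat);  // equals `mat`, since A is identity
}

Note that the dispatch is driven by mat.scalar_type(), so val is read with the same scalar_t and is expected to share mat's dtype.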