Commit ac5d7a78 authored by rusty1s

all kernels

parent 8b07240b
@@ -2,18 +2,37 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
-at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
-                     at::Tensor mat);
+at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col,
+                     at::optional<at::Tensor> val, at::Tensor mat,
+                     std::string reduce);
-at::Tensor spmm(at::Tensor rowptr, at::Tensor col, at::Tensor val,
-                at::Tensor mat) {
+std::tuple<at::Tensor, at::Tensor>
+spmm_arg_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
+              at::Tensor mat, std::string reduce);
+at::Tensor spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
+                at::Tensor mat, std::string reduce) {
+  CHECK_CUDA(rowptr);
+  CHECK_CUDA(col);
+  if (val.has_value())
+    CHECK_CUDA(val.value());
+  CHECK_CUDA(mat);
+  return spmm_cuda(rowptr, col, val, mat, reduce);
+}
+std::tuple<at::Tensor, at::Tensor> spmm_arg(at::Tensor rowptr, at::Tensor col,
+                                            at::optional<at::Tensor> val,
+                                            at::Tensor mat,
+                                            std::string reduce) {
   CHECK_CUDA(rowptr);
   CHECK_CUDA(col);
-  CHECK_CUDA(val);
+  if (val.has_value())
+    CHECK_CUDA(val.value());
   CHECK_CUDA(mat);
-  return spmm_cuda(rowptr, col, val, mat);
+  return spmm_arg_cuda(rowptr, col, val, mat, reduce);
 }
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
+  m.def("spmm_arg", &spmm_arg, "Sparse Matrix Multiplication With Arg (CUDA)");
 }
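For reference, a minimal usage sketch of the bindings above, written against the libtorch C++ API. The example tensors and the `example()` wrapper are illustrative assumptions, not part of this commit; it assumes the snippet is compiled into the same extension so the `spmm` / `spmm_arg` functions declared above are visible.

// Sketch only: calls the new reduce-aware entry points with a tiny CSR matrix.
#include <torch/extension.h>
#include <tuple>

void example() {
  auto iopts = torch::dtype(torch::kInt64).device(torch::kCUDA);
  auto fopts = torch::dtype(torch::kFloat).device(torch::kCUDA);

  // CSR representation of the 2x3 sparse matrix [[1, 0, 2], [0, 3, 0]].
  auto rowptr = torch::tensor({0, 2, 3}, iopts);
  auto col = torch::tensor({0, 2, 1}, iopts);
  auto val = torch::tensor({1.0f, 2.0f, 3.0f}, fopts);
  auto mat = torch::rand({3, 4}, fopts);

  // Weighted row-wise sum and mean.
  auto out_add = spmm(rowptr, col, val, mat, "add");
  auto out_mean = spmm(rowptr, col, val, mat, "mean");

  // Max reduction without explicit values (treated as all-ones), plus argmax.
  at::Tensor out_max, arg_max;
  std::tie(out_max, arg_max) = spmm_arg(rowptr, col, at::nullopt, mat, "max");
}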
@@ -3,24 +3,27 @@
 #include "compat.cuh"
-#define Y_SIZE 32
 #define THREADS 256
+#define ADD 0
+#define MEAN 1
+#define MIN 2
+#define MAX 3
 // Paper: Design Principles for Sparse Matrix Multiplication on the GPU
 // Code: https://github.com/owensgroup/merge-spmm
-template <typename scalar_t>
-__global__ void
-spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
-                const scalar_t *val_data, const scalar_t *mat_data,
-                scalar_t *out_data, size_t N, size_t K) {
-  // We ignore blockIdx.y here, because threads across blockIdx.y operate on the
-  // same row.
+template <typename scalar_t, int64_t REDUCE, bool HAS_VAL>
+__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
+                            const scalar_t *val_data, const scalar_t *mat_data,
+                            scalar_t *out_data, int64_t *arg_out_data, size_t N,
+                            size_t K) {
+  // We ignore blockIdx.y here, because threads
+  // across `blockIdx.y` are treated equally.
   int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int warp_idx = thread_idx >> 5;       // thread_id / 32
+  int row = thread_idx >> 5;            // thread_id / 32
   int lane_idx = thread_idx & (32 - 1); // thread_id % 32
-  int row = warp_idx; // Each warp processes exactly one row.
   // Compute the column index of `mat` in which the thread is operating.
   int mat_col_idx = lane_idx + (blockIdx.y << 5);
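To make the index math in the hunk above concrete, here is a small host-side sketch (an illustration of mine, not part of the commit) that reproduces the warp/lane arithmetic for one sample thread, using the same THREADS = 256 block size as the host launchers later in this diff.

#include <cstdio>

// Mirrors spmm_kernel's mapping: every 32-thread warp handles one sparse row,
// and blockIdx.y selects which 32-wide slice of `mat`'s columns a lane covers.
int main() {
  const int THREADS = 256; // blockDim.x used at launch time
  int blockIdx_x = 3, blockIdx_y = 2, threadIdx_x = 70;

  int thread_idx = THREADS * blockIdx_x + threadIdx_x; // 838
  int row = thread_idx >> 5;                           // 838 / 32 = 26
  int lane_idx = thread_idx & (32 - 1);                // 838 % 32 = 6
  int mat_col_idx = lane_idx + (blockIdx_y << 5);      // 6 + 2 * 32 = 70

  std::printf("row=%d lane=%d mat_col=%d\n", row, lane_idx, mat_col_idx);
  return 0;
}

The launch geometry follows from this mapping: grid.x = (32 * N + THREADS - 1) / THREADS gives one warp per row, and grid.y = (K + 31) / 32 gives one 32-column slice per blockIdx.y.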
@@ -29,9 +32,10 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
   int out_idx = row * K + lane_idx + (blockIdx.y << 5);
   // Helper arrays for warp communication.
-  int mat_row_all[Y_SIZE];
-  scalar_t val_all[Y_SIZE];
+  int mat_rows[32];
+  scalar_t vals[32];
+  // Do not aggregate/write across the Y-axis (lane_idx < leftover).
   int leftover = K - (blockIdx.y << 5);
   if (row < N) {
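The new `leftover` comment above guards the ragged last column block. A tiny standalone example (numbers chosen by me, not from the commit) of how many lanes stay active per blockIdx.y:

#include <cstdio>

// For K = 40 dense columns, blockIdx.y walks over 32-wide column slices;
// only the first `leftover` lanes of the last slice may aggregate/write.
int main() {
  const int K = 40;
  for (int blockIdx_y = 0; blockIdx_y * 32 < K; blockIdx_y++) {
    int leftover = K - (blockIdx_y << 5);
    int active = leftover < 32 ? leftover : 32; // lanes with lane_idx < leftover
    std::printf("blockIdx.y=%d -> leftover=%d, active lanes=%d\n", blockIdx_y,
                leftover, active);
  }
  return 0; // prints: y=0 -> leftover=40 (32 active), y=1 -> leftover=8 (8 active)
}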
@@ -39,17 +43,27 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
     int row_end = __ldg(rowptr_data + row + 1);
     int col_idx = row_start + lane_idx;
-    int mat_row = -1;
-    scalar_t val = (scalar_t)0;
-    scalar_t sum = (scalar_t)0;
-    // Iterate over all col indices in parallel with 32 threads.
+    int mat_row;
+    scalar_t val, result;
+    int64_t arg_result = -1;
+    // Dependent on `reduce`, we need to initialize `result` accordingly.
+    if (REDUCE == ADD)
+      result = (scalar_t)0;
+    else if (REDUCE == MEAN)
+      result = (scalar_t)0;
+    else if (REDUCE == MIN)
+      result = std::numeric_limits<scalar_t>::max();
+    else if (REDUCE == MAX)
+      result = std::numeric_limits<scalar_t>::min();
+    // Iterate over all col indices in parallel within a warp.
     for (int c = row_start; c < row_end; c += 32) {
       if (col_idx < row_end) {
         // Coalesced memory access into `col` and `val`.
         mat_row = __ldg(col_data + col_idx) * K;
-        val = __ldg(val_data + col_idx);
+        val = HAS_VAL ? __ldg(val_data + col_idx) : (scalar_t)1;
       } else {
         mat_row = 0;
         val = (scalar_t)0;
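The added branch above selects the reduction's identity element; since REDUCE is a compile-time template parameter, the untaken branches can be folded away. Below is a standalone sketch of the same idea (the `Reduce` enum and `identity()` helper are my own names, not from the commit). Note that for floating-point types `std::numeric_limits<scalar_t>::min()` is the smallest positive value, so `lowest()` is the usual identity for a max reduction.

#include <limits>

// Illustration of compile-time identity selection, mirroring the
// REDUCE / ADD / MEAN / MIN / MAX scheme used by spmm_kernel.
enum Reduce { kAdd = 0, kMean = 1, kMin = 2, kMax = 3 };

template <typename scalar_t, int REDUCE> constexpr scalar_t identity() {
  if (REDUCE == kAdd || REDUCE == kMean)
    return scalar_t(0);                             // x + 0 == x
  else if (REDUCE == kMin)
    return std::numeric_limits<scalar_t>::max();    // min(x, max) == x
  else
    return std::numeric_limits<scalar_t>::lowest(); // max(x, lowest) == x
}

static_assert(identity<int, kAdd>() == 0, "additive identity");
static_assert(identity<float, kMin>() == std::numeric_limits<float>::max(),
              "min identity");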
@@ -57,50 +71,143 @@ spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
       col_idx += 32;
 #pragma unroll
-      for (int i = 0; i < 32; i += Y_SIZE) {
-#pragma unroll
-        for (int j = 0; j < Y_SIZE; j++) {
-          // Communication between *all* threads in a warp.
-          mat_row_all[j] = __shfl_sync(0xffffffff, mat_row, i + j);
-          val_all[j] = __shfl_sync(0xffffffff, val, i + j);
-        }
+      for (int i = 0; i < 32; i++) {
+        // Communication between all threads in a warp.
+        mat_rows[i] = __shfl_sync(0xffffffff, mat_row, i);
+        vals[i] = __shfl_sync(0xffffffff, val, i);
+      }
 #pragma unroll
-        for (int j = 0; j < Y_SIZE; j++) {
-          if (lane_idx < leftover) {
-            // Coalesced memory access into `mat`.
-            sum += val_all[j] * __ldg(mat_data + mat_row_all[j] + mat_col_idx);
+      for (int i = 0; i < 32; i++) {
+        if (lane_idx < leftover && vals[i] != 0) {
+          // Coalesced memory access into `mat`.
+          val = vals[i] * __ldg(mat_data + mat_rows[i] + mat_col_idx);
+          // Aggregate results along row.
+          if (REDUCE == ADD)
+            result += val;
+          else if (REDUCE == MEAN)
+            result += val;
+          else if (REDUCE == MIN) {
+            if (val < result) {
+              result = val;
+              arg_result = row_start + i;
+            }
+          } else if (REDUCE == MAX) {
+            if (val > result) {
+              result = val;
+              arg_result = row_start + i;
+            }
+          }
           }
         }
       }
     }
     if (lane_idx < leftover) {
-      // Coalesced memory access into `out`.
-      out_data[out_idx] = sum;
+      // Coalesced write into `out` (dependent on `reduce`).
+      if (REDUCE == ADD)
+        out_data[out_idx] = result;
+      else if (REDUCE == MEAN)
+        out_data[out_idx] = result / scalar_t(row_end - row_start);
+      else if (REDUCE == MIN) {
+        arg_out_data[out_idx] = arg_result;
+        if (result == std::numeric_limits<scalar_t>::max())
+          out_data[out_idx] = (scalar_t)0;
+        else
+          out_data[out_idx] = result;
+      } else if (REDUCE == MAX) {
+        arg_out_data[out_idx] = arg_result;
+        if (result == std::numeric_limits<scalar_t>::min())
+          out_data[out_idx] = (scalar_t)0;
+        else
+          out_data[out_idx] = result;
+      }
     }
   }
 }
-at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
-                     at::Tensor mat) {
-  auto N = rowptr.numel() - 1;
+at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col,
+                     at::optional<at::Tensor> val, at::Tensor mat,
+                     std::string reduce) {
+  auto N = rowptr.size(0) - 1;
   auto K = mat.size(1);
   auto out = at::empty({N, K}, mat.options());
   auto rowptr_data = rowptr.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto block_dim = dim3(THREADS);
-  auto grid_dim = dim3((32 * N + THREADS - 1) / THREADS, (K + 31) / 32);
+  auto block = dim3(THREADS);
+  auto grid = dim3((32 * N + THREADS - 1) / THREADS, (K + 31) / 32);
+  auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
-    auto val_data = val.DATA_PTR<scalar_t>();
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
-    spmm_row_kernel<scalar_t>
-        <<<grid_dim, block_dim, 0, at::cuda::getCurrentCUDAStream()>>>(
-            rowptr_data, col_data, val_data, mat_data, out_data, N, K);
+    if (val.has_value()) {
+      auto val_data = val.value().DATA_PTR<scalar_t>();
+      if (reduce == "add")
+        spmm_kernel<scalar_t, ADD, true><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, val_data, mat_data, out_data, nullptr, N, K);
+      else if (reduce == "mean")
+        spmm_kernel<scalar_t, MEAN, true><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, val_data, mat_data, out_data, nullptr, N, K);
+    } else {
+      if (reduce == "add")
+        spmm_kernel<scalar_t, ADD, false><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, nullptr, mat_data, out_data, nullptr, N, K);
+      else if (reduce == "mean")
+        spmm_kernel<scalar_t, MEAN, false><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, nullptr, mat_data, out_data, nullptr, N, K);
+    }
   });
   return out;
 }
+std::tuple<at::Tensor, at::Tensor>
+spmm_arg_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
+              at::Tensor mat, std::string reduce) {
+  auto N = rowptr.size(0) - 1;
+  auto K = mat.size(1);
+  auto out = at::empty({N, K}, mat.options());
+  auto arg_out = at::empty({N, K}, rowptr.options());
+  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
+  auto col_data = col.DATA_PTR<int64_t>();
+  auto arg_out_data = arg_out.DATA_PTR<int64_t>();
+  auto block = dim3(THREADS);
+  auto grid = dim3((32 * N + THREADS - 1) / THREADS, (K + 31) / 32);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
+    auto mat_data = mat.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+    if (val.has_value()) {
+      auto val_data = val.value().DATA_PTR<scalar_t>();
+      if (reduce == "min")
+        spmm_kernel<scalar_t, MIN, true><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, val_data, mat_data, out_data, arg_out_data,
+            N, K);
+      else if (reduce == "max")
+        spmm_kernel<scalar_t, MAX, true><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, val_data, mat_data, out_data, arg_out_data,
+            N, K);
+    } else {
+      if (reduce == "min")
+        spmm_kernel<scalar_t, MIN, false><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, N,
+            K);
+      else if (reduce == "max")
+        spmm_kernel<scalar_t, MAX, false><<<grid, block, 0, stream>>>(
+            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, N,
+            K);
+    }
+  });
+  return std::make_tuple(out, arg_out);
+}
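Finally, a minimal self-contained CUDA sketch (mine, not from the commit) of the __shfl_sync broadcast pattern the kernel relies on: each lane loads one element, then every lane reads all 32 loaded values through warp shuffles instead of shared memory.

#include <cstdio>
#include <cuda_runtime.h>

// One warp: lane i loads in[i], then all lanes accumulate every lane's value
// via __shfl_sync -- the same communication used for `mat_rows` / `vals`.
__global__ void warp_broadcast_sum(const float *in, float *out) {
  int lane = threadIdx.x & (32 - 1);
  float mine = in[lane];

  float sum = 0.0f;
  for (int i = 0; i < 32; i++)
    sum += __shfl_sync(0xffffffff, mine, i); // read lane i's loaded value

  if (lane == 0)
    *out = sum; // all lanes hold the same sum; lane 0 writes it
}

int main() {
  float h_in[32], h_out, *d_in, *d_out;
  for (int i = 0; i < 32; i++)
    h_in[i] = float(i); // 0 + 1 + ... + 31 = 496
  cudaMalloc(&d_in, 32 * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, 32 * sizeof(float), cudaMemcpyHostToDevice);
  warp_broadcast_sum<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("sum = %.1f\n", h_out); // 496.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}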