Commit 36d045fd authored by rusty1s

sparse matrix multiplication kernel

parent 51834e88
cuda/spmm.cpp

#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")

at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
                     at::Tensor mat);

at::Tensor spmm(at::Tensor rowptr, at::Tensor col, at::Tensor val,
                at::Tensor mat) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(val);
  CHECK_CUDA(mat);
  return spmm_cuda(rowptr, col, val, mat);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
}
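Once built (the setup.py hunk below registers these sources as the torch_sparse.spmm_cuda extension), the binding can be called from Python roughly as follows. This is a minimal sketch; the concrete 3x3 CSR matrix and the [3, 8] dense operand are illustrative:

import torch
import torch_sparse.spmm_cuda as spmm_cuda

# CSR representation of [[1, 0, 2], [0, 0, 0], [0, 3, 0]]:
rowptr = torch.tensor([0, 2, 2, 3], device='cuda')
col = torch.tensor([0, 2, 1], device='cuda')
val = torch.tensor([1., 2., 3.], device='cuda')
mat = torch.randn(3, 8, device='cuda')  # dense [M, K] operand

out = spmm_cuda.spmm(rowptr, col, val, mat)  # dense [N, K] = [3, 8] result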
cuda/spmm_kernel.cu

#include <ATen/ATen.h>

#include "compat.cuh"

#define THREADS (32 * 16)

// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code:  https://github.com/owensgroup/merge-spmm
template <typename scalar_t, size_t Y_SIZE>
__global__ void
spmm_row_kernel(const int64_t *rowptr_data, const int64_t *col_data,
                const scalar_t *val_data, const scalar_t *mat_data,
                scalar_t *out_data, size_t N, size_t M, size_t K) {
  // We ignore blockIdx.y here, because threads across blockIdx.y operate on
  // the same row.
  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  int warp_idx = thread_idx >> 5;       // thread_idx / 32
  int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
  int row = warp_idx; // Each warp processes exactly one row.

  // Compute the column index of `mat` in which the thread is operating.
  int mat_col_idx = lane_idx + (blockIdx.y << 5);

  // Compute the output index, given in row-major order.
  int out_idx = row * K + mat_col_idx;

  // Helper arrays for warp communication.
  int mat_row_all[Y_SIZE];
  scalar_t val_all[Y_SIZE];

  // Number of valid `mat` columns left in this blockIdx.y slice; guards all
  // reads and writes below via `lane_idx < leftover`.
  int leftover = K - (blockIdx.y << 5);
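  // Example: with blockDim.x = THREADS = 512 (16 warps per block), thread 70
  // of block (0, 1) has warp_idx = 2 and lane_idx = 6, so it processes row 2,
  // reads `mat` column 6 + (1 << 5) = 38, and writes to out[2 * K + 38].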
  if (row < N) {
    int row_start = __ldg(rowptr_data + row);
    int row_end = __ldg(rowptr_data + row + 1);
    scalar_t sum = (scalar_t)0;

    // Iterate over all col indices of the row in chunks of warp size. Every
    // lane runs the same number of iterations, so the full-mask `__shfl_sync`
    // below is always reached by all 32 threads of the warp.
    for (int chunk = row_start; chunk < row_end; chunk += 32) {
      int col_idx = chunk + lane_idx;
      int mat_row = 0;
      scalar_t val = (scalar_t)0;
      if (col_idx < row_end) {
        // Coalesced memory access into `col` and `val`. Lanes past the end of
        // the row keep a zero value (and a valid dummy row offset).
        mat_row = __ldg(col_data + col_idx) * K;
        val = __ldg(val_data + col_idx);
      }
      for (int i = 0; i < 32; i += Y_SIZE) {
#pragma unroll
        for (int j = 0; j < Y_SIZE; j++) {
          // Warp communication with *all* threads (mask = 0xffffffff).
          mat_row_all[j] = __shfl_sync(0xffffffff, mat_row, i + j);
          val_all[j] = __shfl_sync(0xffffffff, val, i + j);
        }
#pragma unroll
        for (int j = 0; j < Y_SIZE; j++) {
          if (lane_idx < leftover) {
            // Coalesced memory access into `mat`.
            sum += val_all[j] * __ldg(mat_data + mat_row_all[j] + mat_col_idx);
          }
        }
      }
    }

    if (lane_idx < leftover) {
      out_data[out_idx] = sum;
    }
  }
}
at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col, at::Tensor val,
                     at::Tensor mat) {
  // TODO: Set device.
  auto N = rowptr.numel() - 1;
  auto M = mat.size(0);
  auto K = mat.size(1);
  auto out = at::empty({N, K}, mat.options());

  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();
  auto val_data = val.DATA_PTR<float>();
  auto mat_data = mat.DATA_PTR<float>();
  auto out_data = out.DATA_PTR<float>();

  // One warp (32 threads) per row: THREADS / 32 rows per block along x, and
  // ceil(K / 32) blocks along y to cover all `mat` columns.
  auto block_dim = dim3(THREADS);
  auto grid_dim = dim3((32 * N + THREADS - 1) / THREADS, (K + 32 - 1) / 32);

  spmm_row_kernel<float, 32><<<grid_dim, block_dim, 0 /*, cuda_stream */>>>(
      rowptr_data, col_data, val_data, mat_data, out_data, N, M, K);

  return out;
}
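For reference, the computation the kernel parallelizes (one warp per row, one lane per output column, 32 columns per blockIdx.y slice) is equivalent to the following sequential sketch. Pure Python for illustration only; spmm_reference is not part of the extension:

def spmm_reference(rowptr, col, val, mat):
    # rowptr: length N + 1; col/val: CSR column indices and values;
    # mat: dense M x K matrix as nested lists. Returns the dense N x K product.
    N, K = len(rowptr) - 1, len(mat[0])
    out = [[0.0] * K for _ in range(N)]
    for row in range(N):  # one warp per row in the kernel
        for e in range(rowptr[row], rowptr[row + 1]):  # nonzeros of the row
            for k in range(K):  # one lane per column, 32 columns at a time
                out[row][k] += val[e] * mat[col[e]][k]
    return out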
setup.py

@@ -30,6 +30,10 @@ if CUDA_HOME is not None and GPU:
     extra_link_args = ['-lcusparse', '-l', 'cusparse']

     ext_modules += [
+        CUDAExtension('torch_sparse.spmm_cuda',
+                      ['cuda/spmm.cpp', 'cuda/spmm_kernel.cu'],
+                      extra_link_args=extra_link_args,
+                      extra_compile_args=extra_compile_args),
         CUDAExtension('torch_sparse.spspmm_cuda',
                       ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
                       extra_link_args=extra_link_args,
...
@@ -59,11 +59,11 @@ class SparseTensor(object):
         return self._index, self._value

     def csr(self):
-        return self._col, self._rowptr, self._value
+        return self._rowptr, self._col, self._value

     def csc(self):
         perm = self._arg_csr_to_csc
-        return self._row[perm], self._colptr, self._value[perm]
+        return self._colptr, self._row[perm], self._value[perm]

     def is_quadratic(self):
         return self.sparse_size[0] == self.sparse_size[1]
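The fix above makes csr() and csc() return the pointer array first, i.e. the (rowptr, col, value) order the CUDA kernel above consumes. For illustration with hypothetical values, for the matrix [[1, 0], [2, 3]]:

# csr() -> rowptr, col, value
[0, 1, 3], [0, 0, 1], [1., 2., 3.]
# csc() -> colptr, row, value
[0, 2, 3], [0, 1, 1], [1., 2., 3.]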
@@ -103,24 +103,26 @@ class SparseTensor(object):
         return self.__class__.from_storage(storage)

     def matmul(self, mat2):
-        pass
+        raise NotImplementedError

     def coalesce(self, reduce='add'):
-        pass
+        raise NotImplementedError

     def is_coalesced(self):
-        pass
+        raise NotImplementedError

     def add(self, layout=None):
         # sub, mul, div
         # can take scalars, tensors and other sparse matrices
         # inplace variants can only take scalars or tensors
-        pass
+        raise NotImplementedError

+    # TODO: Slicing, (sum|max|min|prod|...), standard operators, masking, perm
+
     def to_dense(self, dtype=None):
         dtype = dtype or self.dtype
         mat = torch.zeros(self.size(), dtype=dtype, device=self.device)
-        mat[self._row, self._col] = self._value or 1
+        mat[self._row, self._col] = self._value if self.has_value else 1
         return mat

     def to_scipy(self):
@@ -129,8 +131,6 @@ class SparseTensor(object):
     def to_torch_sparse_coo_tensor(self):
         raise NotImplementedError

-    # TODO: Slicing, (sum|max|min|prod|...), standard operators, masking, perm
-
     def __repr__(self):
         i = ' ' * 6
         index, value = self.coo()
...