Commit 6a7f10e5 authored by rusty1s

matmul complete

parent 0fd716cb
@@ -174,46 +174,56 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return std::make_tuple(out, arg_out);
 }
 
-at::Tensor spmm_val_bw(at::Tensor rowptr, at::Tensor col, at::Tensor mat,
+at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
                        at::Tensor grad, std::string reduce) {
+  CHECK_CPU(index);
   CHECK_CPU(rowptr);
-  CHECK_CPU(col);
   CHECK_CPU(mat);
   CHECK_CPU(grad);
 
+  AT_ASSERTM(index.dim() == 2, "Input mismatch");
+  AT_ASSERTM(index.size(0) == 2, "Input mismatch");
+  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
+  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
+  AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
+  AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
+                 reduce2REDUCE.at(reduce) == MEAN,
+             "Reduce operation not supported");
+
+  index = index.contiguous();
   mat = mat.contiguous();
+  grad = grad.contiguous();
 
-  auto M = rowptr.numel() - 1;
+  auto M = grad.size(-2);
   auto N = mat.size(-2);
+  auto E = index.size(1);
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);
 
-  auto out = at::zeros(col.sizes(), grad.options());
+  auto out = at::zeros(index.size(1), grad.options());
 
+  auto index_data = index.DATA_PTR<int64_t>();
   auto rowptr_data = rowptr.DATA_PTR<int64_t>();
-  auto col_data = col.DATA_PTR<int64_t>();
 
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw", [&] {
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto grad_data = grad.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
 
     scalar_t val;
-    int64_t row_start, row_end, c;
+    int64_t row, col;
     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
       for (int b = 0; b < B; b++) {
-        for (int m = 0; m < M; m++) {
-          row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
-          for (int e = row_start; e < row_end; e++) {
-            c = col_data[e], val = (scalar_t)0;
-            for (int k = 0; k < K; k++) {
-              val += mat_data[b * N * K + c * K + k] *
-                     grad_data[b * M * K + m * K + k];
-            }
-            if (REDUCE == MEAN)
-              val = val / (scalar_t)(row_end - row_start);
-            out_data[e] += val;
-          }
+        for (int e = 0; e < E; e++) {
+          row = index_data[e], col = index_data[E + e], val = (scalar_t)0;
+          for (int k = 0; k < K; k++) {
+            val += mat_data[b * N * K + col * K + k] *
+                   grad_data[b * M * K + row * K + k];
+          }
+          if (REDUCE == MEAN) {
+            int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
+            val /= (scalar_t)std::max(row_end - row_start, 1);
+          }
+          out_data[e] += val;
         }
       }
     });
...
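
For reference, the value backward added above has simple dense semantics: for every nonzero e with coordinates (row_e, col_e) taken from the 2 x E `index` tensor, the gradient of its value is the inner product of grad[row_e] and mat[col_e], divided by that row's nonzero count (read from `rowptr`) when reduce is 'mean'. A minimal PyTorch sketch of the same computation for the non-batched 2-D case; the helper name is hypothetical and not part of the commit:

import torch

def spmm_val_bw_reference(index, rowptr, mat, grad, reduce='sum'):
    # index: [2, E] COO coordinates, rowptr: [M + 1] CSR row pointers,
    # mat: [N, K] dense input, grad: [M, K] gradient of the spmm output.
    row, col = index[0], index[1]
    # Gradient of nonzero e is <grad[row_e], mat[col_e]>.
    grad_value = (grad[row] * mat[col]).sum(dim=-1)
    if reduce == 'mean':
        # Normalize by the row's nonzero count, clamped to at least 1.
        deg = (rowptr[1:] - rowptr[:-1]).clamp(min=1)
        grad_value = grad_value / deg[row].to(grad_value.dtype)
    return grad_value
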
@@ -6,6 +6,9 @@ std::tuple<at::Tensor, at::optional<at::Tensor>>
 spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
           at::Tensor mat, std::string reduce);
 
+at::Tensor spmm_val_bw_cuda(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
+                            at::Tensor grad, std::string reduce);
+
 std::tuple<at::Tensor, at::optional<at::Tensor>>
 spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
      at::Tensor mat, std::string reduce) {
@@ -17,6 +20,17 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return spmm_cuda(rowptr, col, value_opt, mat, reduce);
 }
 
+at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
+                       at::Tensor grad, std::string reduce) {
+  CHECK_CUDA(index);
+  CHECK_CUDA(rowptr);
+  CHECK_CUDA(mat);
+  CHECK_CUDA(grad);
+  return spmm_val_bw_cuda(index, rowptr, mat, grad, reduce);
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
+  m.def("spmm_val_bw", &spmm_val_bw,
+        "Sparse-Dense Matrix Multiplication Value Backward (CPU)");
 }
@@ -99,12 +99,11 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
   // Helper arrays for warp communication.
   int mat_row, mat_rows[32];
   scalar_t val, vals[HAS_VAL ? 32 : 1];
-  int bla, blas[32];
 
   // Do not aggregate/write across the Y-axis (lane_idx < leftover).
   int leftover = K - (blockIdx.y << 5);
 
-  if (row < B * M) {
+  if (batch_idx < B) {
     int row_start = __ldg(rowptr_data + (row % M));
     int row_end = __ldg(rowptr_data + (row % M) + 1);
     int col_idx = row_start + lane_idx;
@@ -118,12 +117,10 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
       if (col_idx < row_end) {
         // Coalesced memory access into `col` and `val`.
         mat_row = __ldg(col_data + col_idx) * K;
-        bla = col_idx;
         if (HAS_VAL)
           val = __ldg(value_data + col_idx);
       } else {
         mat_row = -1;
-        bla = -1;
         if (HAS_VAL)
           val = (scalar_t)0;
       }
@@ -133,7 +130,6 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
       for (int i = 0; i < 32; i++) {
         // Communication between all threads in a warp.
        mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
-        blas[i] = __shfl_sync(FULL_MASK, bla, i);
         if (HAS_VAL)
           vals[i] = __shfl_sync(FULL_MASK, val, i);
       }
@@ -182,9 +178,6 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
 
-  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
-  auto col_data = col.DATA_PTR<int64_t>();
-
   auto M = rowptr.numel() - 1;
   auto N = mat.size(-2);
   auto K = mat.size(-1);
@@ -193,6 +186,8 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
+    auto rowptr_data = rowptr.DATA_PTR<int64_t>();
+    auto col_data = col.DATA_PTR<int64_t>();
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
@@ -212,3 +207,84 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return std::make_tuple(out, arg_out);
 }
+
+template <typename scalar_t, ReductionType REDUCE>
+__global__ void
+spmm_val_bw_kernel(const int64_t *index_data, const int64_t *rowptr_data,
+                   const scalar_t *mat_data, const scalar_t *grad_data,
+                   scalar_t *out_data, int B, int M, int N, int E, int K) {
+  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+  int index_idx = (thread_idx >> 5);    // thread_idx / 32
+  int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
+
+  if (index_idx < E) {
+    int row = __ldg(index_data + index_idx);
+    int col = __ldg(index_data + E + index_idx);
+
+    scalar_t val = (scalar_t)0;
+    for (int b = 0; b < B; b++) {
+      for (int k = lane_idx; k < K; k += 32) {
+        val += mat_data[b * N * K + col * K + k] *
+               grad_data[b * M * K + row * K + k];
+      }
+    }
+
+#pragma unroll
+    for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
+      val += __shfl_down_sync(FULL_MASK, val, i);
+    }
+
+    if (lane_idx == 0) {
+      if (REDUCE == MEAN) {
+        int row_start = __ldg(rowptr_data + row);
+        int row_end = __ldg(rowptr_data + row + 1);
+        val /= (scalar_t)max(row_end - row_start, 1);
+      }
+      out_data[index_idx] = val;
+    }
+  }
+}
+
+at::Tensor spmm_val_bw_cuda(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
+                            at::Tensor grad, std::string reduce) {
+  AT_ASSERTM(index.dim() == 2, "Input mismatch");
+  AT_ASSERTM(index.size(0) == 2, "Input mismatch");
+  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
+  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
+  AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
+  AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
+                 reduce2REDUCE.at(reduce) == MEAN,
+             "Reduce operation not supported");
+
+  index = index.contiguous();
+  mat = mat.contiguous();
+  grad = grad.contiguous();
+
+  auto M = grad.size(-2);
+  auto N = mat.size(-2);
+  auto E = index.size(1);
+  auto K = mat.size(-1);
+  auto B = mat.numel() / (N * K);
+  auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
+
+  auto out = at::empty(index.size(1), grad.options());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
+    auto index_data = index.DATA_PTR<int64_t>();
+    auto rowptr_data = rowptr.DATA_PTR<int64_t>();
+    auto mat_data = mat.DATA_PTR<scalar_t>();
+    auto grad_data = grad.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      spmm_val_bw_kernel<scalar_t, REDUCE>
+          <<<BLOCKS, THREADS, 0, stream>>>(index_data, rowptr_data, mat_data,
+                                           grad_data, out_data, B, M, N, E, K);
+    });
+  });
+
+  return out;
+}
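
The CUDA kernel above parallelizes the same computation over nonzeros: each edge index_idx is handled by one 32-lane warp, the lanes stride over the K feature columns accumulating partial dot products, the partials are folded together by the __shfl_down_sync loop, and lane 0 applies the optional mean normalization and writes out_data[index_idx]. A small Python illustration of that halving (tree) reduction, with a plain list standing in for the warp's lanes; purely illustrative, not the kernel itself:

def warp_tree_reduce(lane_vals):
    # Mimics the __shfl_down_sync loop: each step adds the value held by
    # the lane `offset` positions higher, halving `offset` until lane 0
    # holds the sum of all 32 partial results.
    vals = list(lane_vals)
    offset = len(vals) // 2
    while offset > 0:
        for lane in range(offset):
            vals[lane] += vals[lane + offset]
        offset //= 2
    return vals[0]

assert warp_tree_reduce(range(32)) == sum(range(32))
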
@@ -9,10 +9,7 @@ import torch_scatter
 
 from .utils import devices, grad_dtypes
 
-devices = ['cpu', 'cuda']
-grad_dtypes = [torch.float]
 reductions = ['sum', 'mean', 'min', 'max']
-reductions = ['min', 'max']
 
 @pytest.mark.parametrize('dtype,device,reduce',
...
@@ -44,11 +44,11 @@ class SPMM(torch.autograd.Function):
         if ctx.needs_input_grad[5]:
             if ctx.reduce in ['sum', 'add']:
                 grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                    rowptr, index[1], mat, grad_out, ctx.reduce)
+                    index, rowptr, mat, grad_out, ctx.reduce)
 
             if ctx.reduce == 'mean':
                 grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                    rowptr, index[1], mat, grad_out, ctx.reduce)
+                    index, rowptr, mat, grad_out, ctx.reduce)
 
             elif ctx.reduce in ['min', 'max']:
                 col = index[1][arg_out_ind.flatten()].view_as(arg_out)
@@ -108,5 +108,6 @@ def matmul(src, other, reduce='sum'):
 
     elif isinstance(other, src.__class__):
         assert reduce in ['sum', 'add']
+        raise NotImplementedError
 
     raise ValueError
@@ -12,6 +12,7 @@ from torch_sparse.index_select import index_select, index_select_nnz
 from torch_sparse.masked_select import masked_select, masked_select_nnz
 import torch_sparse.reduce
 from torch_sparse.diag import remove_diag
+from torch_sparse.matmul import matmul
 
 
 class SparseTensor(object):
@@ -410,6 +411,9 @@ class SparseTensor(object):
 
         return out
 
+    def __matmul__(a, b):
+        return matmul(a, b, reduce='sum')
+
     # String Reputation #######################################################
 
     def __repr__(self):
@@ -446,6 +450,7 @@ SparseTensor.mean = torch_sparse.reduce.mean
 SparseTensor.min = torch_sparse.reduce.min
 SparseTensor.max = torch_sparse.reduce.max
 SparseTensor.remove_diag = remove_diag
+SparseTensor.matmul = matmul
 
 # SparseTensor.add = add
 # SparseTensor.add_nnz = add_nnz
...
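
With matmul attached to SparseTensor, a sparse-dense product can be written either through the explicit function or the new @ operator, which defaults to sum reduction. A hypothetical usage sketch; the top-level SparseTensor import and the from_dense constructor are assumptions about this revision, not part of the diff:

import torch
from torch_sparse import SparseTensor           # assumed top-level export
from torch_sparse.matmul import matmul

dense_adj = torch.tensor([[1., 0., 2.],
                          [0., 3., 0.]])
adj = SparseTensor.from_dense(dense_adj)        # assumed constructor
other = torch.randn(3, 4)

out1 = matmul(adj, other, reduce='sum')         # explicit reduction argument
out2 = adj @ other                              # __matmul__ defined above
assert torch.allclose(out1, dense_adj @ other)  # sum reduction == dense matmul
assert torch.allclose(out1, out2)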