Commit 6a7f10e5 authored by rusty1s's avatar rusty1s
Browse files

matmul complete

parent 0fd716cb
......@@ -174,46 +174,56 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
return std::make_tuple(out, arg_out);
}
at::Tensor spmm_val_bw(at::Tensor rowptr, at::Tensor col, at::Tensor mat,
at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
at::Tensor grad, std::string reduce) {
CHECK_CPU(index);
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(mat);
CHECK_CPU(grad);
AT_ASSERTM(index.dim() == 2, "Input mismatch");
AT_ASSERTM(index.size(0) == 2, "Input mismatch");
AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
reduce2REDUCE.at(reduce) == MEAN,
"Reduce operation not supported");
index = index.contiguous();
mat = mat.contiguous();
grad = grad.contiguous();
auto M = rowptr.numel() - 1;
auto M = grad.size(-2);
auto N = mat.size(-2);
auto E = index.size(1);
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto out = at::zeros(col.sizes(), grad.options());
auto out = at::zeros(index.size(1), grad.options());
auto index_data = index.DATA_PTR<int64_t>();
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw", [&] {
auto mat_data = mat.DATA_PTR<scalar_t>();
auto grad_data = grad.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
scalar_t val;
int64_t row_start, row_end, c;
int64_t row, col;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int b = 0; b < B; b++) {
for (int m = 0; m < M; m++) {
row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (int e = row_start; e < row_end; e++) {
c = col_data[e], val = (scalar_t)0;
for (int k = 0; k < K; k++) {
val += mat_data[b * N * K + c * K + k] *
grad_data[b * M * K + m * K + k];
}
if (REDUCE == MEAN)
val = val / (scalar_t)(row_end - row_start);
out_data[e] += val;
for (int e = 0; e < E; e++) {
row = index_data[e], col = index_data[E + e], val = (scalar_t)0;
for (int k = 0; k < K; k++) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
if (REDUCE == MEAN) {
int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
val /= (scalar_t)std::max(row_end - row_start, 1);
}
out_data[e] += val;
}
}
});
......
......@@ -6,6 +6,9 @@ std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
at::Tensor mat, std::string reduce);
at::Tensor spmm_val_bw_cuda(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
at::Tensor grad, std::string reduce);
std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
at::Tensor mat, std::string reduce) {
......@@ -17,6 +20,17 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
return spmm_cuda(rowptr, col, value_opt, mat, reduce);
}
at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
at::Tensor grad, std::string reduce) {
CHECK_CUDA(index);
CHECK_CUDA(rowptr);
CHECK_CUDA(mat);
CHECK_CUDA(grad);
return spmm_val_bw_cuda(index, rowptr, mat, grad, reduce);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
m.def("spmm_val_bw", &spmm_val_bw,
"Sparse-Dense Matrix Multiplication Value Backward (CPU)");
}
......@@ -99,12 +99,11 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
// Helper arrays for warp communication.
int mat_row, mat_rows[32];
scalar_t val, vals[HAS_VAL ? 32 : 1];
int bla, blas[32];
// Do not aggregate/write across the Y-axis (lane_idx < leftover).
int leftover = K - (blockIdx.y << 5);
if (row < B * M) {
if (batch_idx < B) {
int row_start = __ldg(rowptr_data + (row % M));
int row_end = __ldg(rowptr_data + (row % M) + 1);
int col_idx = row_start + lane_idx;
......@@ -118,12 +117,10 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
if (col_idx < row_end) {
// Coalesced memory access into `col` and `val`.
mat_row = __ldg(col_data + col_idx) * K;
bla = col_idx;
if (HAS_VAL)
val = __ldg(value_data + col_idx);
} else {
mat_row = -1;
bla = -1;
if (HAS_VAL)
val = (scalar_t)0;
}
......@@ -133,7 +130,6 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
for (int i = 0; i < 32; i++) {
// Communication between all threads in a warp.
mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
blas[i] = __shfl_sync(FULL_MASK, bla, i);
if (HAS_VAL)
vals[i] = __shfl_sync(FULL_MASK, val, i);
}
......@@ -182,9 +178,6 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
auto M = rowptr.numel() - 1;
auto N = mat.size(-2);
auto K = mat.size(-1);
......@@ -193,6 +186,8 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
auto mat_data = mat.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
......@@ -212,3 +207,84 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
return std::make_tuple(out, arg_out);
}
template <typename scalar_t, ReductionType REDUCE>
__global__ void
spmm_val_bw_kernel(const int64_t *index_data, const int64_t *rowptr_data,
const scalar_t *mat_data, const scalar_t *grad_data,
scalar_t *out_data, int B, int M, int N, int E, int K) {
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int index_idx = (thread_idx >> 5); // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
if (index_idx < E) {
int row = __ldg(index_data + index_idx);
int col = __ldg(index_data + E + index_idx);
scalar_t val = (scalar_t)0;
for (int b = 0; b < B; b++) {
for (int k = lane_idx; k < K; k += 32) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
}
#pragma unroll
for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
val += __shfl_down_sync(FULL_MASK, val, i);
}
if (lane_idx == 0) {
if (REDUCE == MEAN) {
int row_start = __ldg(rowptr_data + row);
int row_end = __ldg(rowptr_data + row + 1);
val /= (scalar_t)max(row_end - row_start, 1);
}
out_data[index_idx] = val;
}
}
}
at::Tensor spmm_val_bw_cuda(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
at::Tensor grad, std::string reduce) {
AT_ASSERTM(index.dim() == 2, "Input mismatch");
AT_ASSERTM(index.size(0) == 2, "Input mismatch");
AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
reduce2REDUCE.at(reduce) == MEAN,
"Reduce operation not supported");
index = index.contiguous();
mat = mat.contiguous();
grad = grad.contiguous();
auto M = grad.size(-2);
auto N = mat.size(-2);
auto E = index.size(1);
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
auto out = at::empty(index.size(1), grad.options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
auto index_data = index.DATA_PTR<int64_t>();
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto mat_data = mat.DATA_PTR<scalar_t>();
auto grad_data = grad.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
spmm_val_bw_kernel<scalar_t, REDUCE>
<<<BLOCKS, THREADS, 0, stream>>>(index_data, rowptr_data, mat_data,
grad_data, out_data, B, M, N, E, K);
});
});
return out;
}
......@@ -9,10 +9,7 @@ import torch_scatter
from .utils import devices, grad_dtypes
devices = ['cpu', 'cuda']
grad_dtypes = [torch.float]
reductions = ['sum', 'mean', 'min', 'max']
reductions = ['min', 'max']
@pytest.mark.parametrize('dtype,device,reduce',
......
......@@ -44,11 +44,11 @@ class SPMM(torch.autograd.Function):
if ctx.needs_input_grad[5]:
if ctx.reduce in ['sum', 'add']:
grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
rowptr, index[1], mat, grad_out, ctx.reduce)
index, rowptr, mat, grad_out, ctx.reduce)
if ctx.reduce == 'mean':
grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
rowptr, index[1], mat, grad_out, ctx.reduce)
index, rowptr, mat, grad_out, ctx.reduce)
elif ctx.reduce in ['min', 'max']:
col = index[1][arg_out_ind.flatten()].view_as(arg_out)
......@@ -108,5 +108,6 @@ def matmul(src, other, reduce='sum'):
elif isinstance(other, src.__class__):
assert reduce in ['sum', 'add']
raise NotImplementedError
raise ValueError
......@@ -12,6 +12,7 @@ from torch_sparse.index_select import index_select, index_select_nnz
from torch_sparse.masked_select import masked_select, masked_select_nnz
import torch_sparse.reduce
from torch_sparse.diag import remove_diag
from torch_sparse.matmul import matmul
class SparseTensor(object):
......@@ -410,6 +411,9 @@ class SparseTensor(object):
return out
def __matmul__(a, b):
return matmul(a, b, reduce='sum')
# String Reputation #######################################################
def __repr__(self):
......@@ -446,6 +450,7 @@ SparseTensor.mean = torch_sparse.reduce.mean
SparseTensor.min = torch_sparse.reduce.min
SparseTensor.max = torch_sparse.reduce.max
SparseTensor.remove_diag = remove_diag
SparseTensor.matmul = matmul
# SparseTensor.add = add
# SparseTensor.add_nnz = add_nnz
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment