Commit fc650310 authored by rusty1s

spspmm kernel: CSR-based sparse-sparse matrix multiplication via cuSPARSE csrgemm2 with optional value tensors, wired into SparseTensor matmul, plus tests

parent 335dfed0
@@ -2,37 +2,61 @@
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, size_t m, size_t k, size_t n);
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, size_t rowA_max, size_t rowB_max);
std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
at::optional<at::Tensor> valueA, at::Tensor rowptrB,
at::Tensor colB, at::optional<at::Tensor> valueB, int M, int N,
int K);
std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor valueA,
at::Tensor indexB, at::Tensor valueB,
size_t m, size_t k, size_t n) {
CHECK_CUDA(indexA);
CHECK_CUDA(valueA);
CHECK_CUDA(indexB);
CHECK_CUDA(valueB);
return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
spspmm(at::Tensor rowptrA, at::Tensor colA, at::optional<at::Tensor> valueA,
at::Tensor rowptrB, at::Tensor colB, at::optional<at::Tensor> valueB,
int M, int N, int K) {
CHECK_CUDA(rowptrA);
CHECK_CUDA(colA);
if (valueA.has_value())
CHECK_CUDA(valueA.value());
CHECK_CUDA(rowptrB);
CHECK_CUDA(colB);
if (valueB.has_value())
CHECK_CUDA(valueB.value());
return spspmm_cuda(rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K);
}
at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
size_t rowB_max) {
CHECK_CUDA(index);
CHECK_CUDA(indexA);
CHECK_CUDA(valueA);
CHECK_CUDA(indexB);
CHECK_CUDA(valueB);
return spspmm_bw_cuda(index, indexA, valueA, indexB, valueB, rowA_max,
rowB_max);
}
// std::tuple<at::Tensor, at::Tensor>
// spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t m, size_t k, size_t n);
// at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
// at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t rowA_max, size_t
// rowB_max);
// std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor
// valueA,
// at::Tensor indexB, at::Tensor
// valueB, size_t m, size_t k, size_t
// n) {
// CHECK_CUDA(indexA);
// CHECK_CUDA(valueA);
// CHECK_CUDA(indexB);
// CHECK_CUDA(valueB);
// return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
// }
// at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
// at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
// size_t rowB_max) {
// CHECK_CUDA(index);
// CHECK_CUDA(indexA);
// CHECK_CUDA(valueA);
// CHECK_CUDA(indexB);
// CHECK_CUDA(valueB);
// return spspmm_bw_cuda(index, indexA, valueA, indexB, valueB, rowA_max,
// rowB_max);
// }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
m.def("spspmm_bw", &spspmm_bw,
"Sparse-Sparse Matrix Multiplication Backward (CUDA)");
// m.def("spspmm_bw", &spspmm_bw,
// "Sparse-Sparse Matrix Multiplication Backward (CUDA)");
}
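For reference, a minimal sketch of how the spspmm binding above is driven from Python, mirroring the call added to matmul.py further down. The CSR tensors are made up for illustration (A is 2x3, B is 3x2) and the sketch assumes the extension is importable as torch_sparse.spspmm_cuda:

```python
import torch
from torch_sparse import spspmm_cuda  # extension built from this file (assumed import path)

# A: 2x3 CSR matrix with entries A[0, 0] = 1 and A[1, 2] = 2.
rowptrA = torch.tensor([0, 1, 2], device='cuda')
colA = torch.tensor([0, 2], device='cuda')
valueA = torch.tensor([1.0, 2.0], device='cuda')

# B: 3x2 CSR matrix with entries B[0, 0] = 3 and B[2, 1] = 4.
rowptrB = torch.tensor([0, 1, 1, 2], device='cuda')
colB = torch.tensor([0, 1], device='cuda')
valueB = torch.tensor([3.0, 4.0], device='cuda')

# (M x N) @ (N x K) with M=2, N=3, K=2; the value tensors may also be passed as None.
indexC, rowptrC, valueC = spspmm_cuda.spspmm(rowptrA, colA, valueA,
                                             rowptrB, colB, valueB, 2, 3, 2)
# indexC: 2 x nnz COO indices, rowptrC: CSR row pointer, valueC: values or None.
```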
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cusparse.h>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseScsrgemm2_bufferSizeExt; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseDcsrgemm2_bufferSizeExt; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
#define CSRGEMM(TYPE, ...) \
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(TYPE, ...) \
[&] { \
const auto &the_type = TYPE; \
(void)the_type; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
return cusparseScsrgemm(__VA_ARGS__); \
const auto &cusparsecsrgemm2 = cusparseScsrgemm2; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Double: { \
using scalar_t = double; \
return cusparseDcsrgemm(__VA_ARGS__); \
const auto &cusparsecsrgemm2 = cusparseDcsrgemm2; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(_st), "'"); \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
static cusparseHandle_t cusparse_handle = 0;
std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
at::optional<at::Tensor> valueA, at::Tensor rowptrB,
at::Tensor colB, at::optional<at::Tensor> valueB, int M, int N,
int K) {
cusparseMatDescr_t descr = 0;
cusparseCreateMatDescr(&descr);
auto handle = at::cuda::getCurrentCUDASparseHandle();
static void init_cusparse() {
if (cusparse_handle == 0) {
cusparseStatus_t status = cusparseCreate(&cusparse_handle);
}
}
rowptrA = rowptrA.toType(at::kInt), colA = colA.toType(at::kInt);
rowptrB = rowptrB.toType(at::kInt), colB = colB.toType(at::kInt);
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, size_t m, size_t k, size_t n) {
cudaSetDevice(indexA.get_device());
init_cusparse();
indexA = indexA.contiguous();
valueA = valueA.contiguous();
indexB = indexB.contiguous();
valueB = valueB.contiguous();
auto nnzA = valueA.size(0);
auto nnzB = valueB.size(0);
indexA = indexA.toType(at::kInt);
indexB = indexB.toType(at::kInt);
// Convert A to CSR format.
auto row_ptrA = at::empty(m + 1, indexA.options());
cusparseXcoo2csr(cusparse_handle, indexA[0].DATA_PTR<int>(), nnzA, k,
row_ptrA.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto colA = indexA[1];
cudaMemcpy(row_ptrA.DATA_PTR<int>() + m, &nnzA, sizeof(int),
cudaMemcpyHostToDevice);
// Convert B to CSR format.
auto row_ptrB = at::empty(k + 1, indexB.options());
cusparseXcoo2csr(cusparse_handle, indexB[0].DATA_PTR<int>(), nnzB, k,
row_ptrB.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto colB = indexB[1];
cudaMemcpy(row_ptrB.DATA_PTR<int>() + k, &nnzB, sizeof(int),
cudaMemcpyHostToDevice);
auto rowptrA_data = rowptrA.DATA_PTR<int>(), colA_data = colA.DATA_PTR<int>();
auto rowptrB_data = rowptrB.DATA_PTR<int>(), colB_data = colB.DATA_PTR<int>();
cusparseMatDescr_t descr = 0;
cusparseCreateMatDescr(&descr);
cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
csrgemm2Info_t info = NULL;
cusparseCreateCsrgemm2Info(&info);
auto scalar_type = at::ScalarType::Float;
if (valueA.has_value())
scalar_type = valueA.value().scalar_type();
if (valueB.has_value())
scalar_type = valueB.value().scalar_type();
size_t bufferSize;
AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(scalar_type, [&] {
scalar_t alpha = (scalar_t)1;
cusparsecsrgemm2_bufferSizeExt(handle, M, N, K, &alpha, descr, colA.numel(),
rowptrA_data, colA_data, descr, colB.numel(),
rowptrB_data, colB_data, NULL, descr, 0,
NULL, NULL, info, &bufferSize);
});
void *buffer = NULL;
cudaMalloc(&buffer, bufferSize);
int nnzC;
auto row_ptrC = at::empty(m + 1, indexB.options());
cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
row_ptrA.DATA_PTR<int>(), colA.DATA_PTR<int>(), descr,
nnzB, row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(),
descr, row_ptrC.DATA_PTR<int>(), &nnzC);
auto colC = at::empty(nnzC, indexA.options());
auto valueC = at::empty(nnzC, valueA.options());
CSRGEMM(valueC.scalar_type(), cusparse_handle,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, m,
n, k, descr, nnzA, valueA.DATA_PTR<scalar_t>(),
row_ptrA.DATA_PTR<int>(), colA.DATA_PTR<int>(), descr, nnzB,
valueB.DATA_PTR<scalar_t>(), row_ptrB.DATA_PTR<int>(),
colB.DATA_PTR<int>(), descr, valueC.DATA_PTR<scalar_t>(),
row_ptrC.DATA_PTR<int>(), colC.DATA_PTR<int>());
auto rowC = at::empty(nnzC, indexA.options());
cusparseXcsr2coo(cusparse_handle, row_ptrC.DATA_PTR<int>(), nnzC, m,
rowC.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);
return std::make_tuple(indexC, valueC);
}
auto rowptrC = at::empty(M + 1, rowptrA.options());
auto rowptrC_data = rowptrC.DATA_PTR<int>();
cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
colA_data, descr, colB.numel(), rowptrB_data, colB_data,
descr, 0, NULL, NULL, descr, rowptrC_data, &nnzC, info,
buffer);
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
auto zero = at::zeros(num_nodes, row.options());
auto one = at::ones(row.size(0), row.options());
return zero.scatter_add_(0, row, one);
}
auto colC = at::empty(nnzC, colA.options());
auto colC_data = colC.DATA_PTR<int>();
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
int64_t num_nodes) {
// Assert already coalesced input.
row = degree(row, num_nodes).cumsum(0);
row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
return std::make_tuple(row, col);
}
if (!valueA.has_value() && valueB.has_value())
valueA = at::ones_like(valueB.value());
template <typename scalar_t>
__global__ void spspmm_bw_kernel(
const int64_t *__restrict__ index, scalar_t *__restrict__ value,
const int64_t *__restrict__ rowA, const int64_t *__restrict__ colA,
const scalar_t *__restrict__ valueA, const int64_t *__restrict__ rowB,
const int64_t *__restrict__ colB, const scalar_t *__restrict__ valueB,
const size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t e = idx; e < numel; e += stride) {
int64_t i = index[e], j = index[numel + e];
for (ptrdiff_t dA = rowA[i]; dA < rowA[i + 1]; dA++) {
int64_t cA = colA[dA];
for (ptrdiff_t dB = rowB[j]; dB < rowB[j + 1]; dB++) {
int64_t cB = colB[dB];
if (cA == cB) {
value[e] += valueA[dA] * valueB[dB];
}
if (cB >= cA) {
break;
}
}
}
}
}
if (!valueB.has_value() && valueA.has_value())
valueB = at::ones_like(valueA.value());
at::optional<at::Tensor> valueC = at::nullopt;
if (valueA.has_value())
valueC = at::empty(nnzC, valueA.value().options());
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, size_t rowA_max, size_t rowB_max) {
cudaSetDevice(index.get_device());
auto value = at::zeros(index.size(1), valueA.options());
AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(scalar_type, [&] {
scalar_t alpha = (scalar_t)1;
at::Tensor rowA, colA;
std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
scalar_t *valueA_data = NULL;
if (valueA.has_value())
valueA_data = valueA.value().DATA_PTR<scalar_t>();
at::Tensor rowB, colB;
std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
scalar_t *valueB_data = NULL;
if (valueB.has_value())
valueB_data = valueB.value().DATA_PTR<scalar_t>();
AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
index.DATA_PTR<int64_t>(), value.DATA_PTR<scalar_t>(),
rowA.DATA_PTR<int64_t>(), colA.DATA_PTR<int64_t>(),
valueA.DATA_PTR<scalar_t>(), rowB.DATA_PTR<int64_t>(),
colB.DATA_PTR<int64_t>(), valueB.DATA_PTR<scalar_t>(), value.numel());
scalar_t *valueC_data = NULL;
if (valueC.has_value())
valueC_data = valueC.value().DATA_PTR<scalar_t>();
cusparsecsrgemm2(handle, M, N, K, &alpha, descr, colA.numel(), valueA_data,
rowptrA_data, colA_data, descr, colB.numel(), valueB_data,
rowptrB_data, colB_data, NULL, descr, 0, NULL, NULL, NULL,
descr, valueC_data, rowptrC_data, colC_data, info, buffer);
});
return value;
auto rowC = at::empty_like(colC);
auto rowC_data = rowC.DATA_PTR<int>();
cusparseXcsr2coo(handle, rowptrC_data, nnzC, M, rowC_data,
CUSPARSE_INDEX_BASE_ZERO);
cusparseDestroyCsrgemm2Info(info);
auto indexC = at::stack({rowC.toType(at::kLong), colC.toType(at::kLong)}, 0);
return std::make_tuple(indexC, rowptrC.toType(at::kLong), valueC);
}
// #define THREADS 1024
// #define BLOCKS(N) (N + THREADS - 1) / THREADS
// #define CSRGEMM(TYPE, ...) \
// [&] { \
// const auto &the_type = TYPE; \
// (void)the_type; \
// at::ScalarType _st = ::detail::scalar_type(TYPE); \
// switch (_st) { \
// case at::ScalarType::Float: { \
// using scalar_t = float; \
// return cusparseScsrgemm(__VA_ARGS__); \
// } \
// case at::ScalarType::Double: { \
// using scalar_t = double; \
// return cusparseDcsrgemm(__VA_ARGS__); \
// } \
// default: \
// AT_ERROR("Not implemented for '", toString(_st), "'"); \
// } \
// }()
// static cusparseHandle_t cusparse_handle = 0;
// static void init_cusparse() {
// if (cusparse_handle == 0) {
// cusparseStatus_t status = cusparseCreate(&cusparse_handle);
// }
// }
// std::tuple<at::Tensor, at::Tensor>
// spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t m, size_t k, size_t n) {
// cudaSetDevice(indexA.get_device());
// init_cusparse();
// indexA = indexA.contiguous();
// valueA = valueA.contiguous();
// indexB = indexB.contiguous();
// valueB = valueB.contiguous();
// auto nnzA = valueA.size(0);
// auto nnzB = valueB.size(0);
// indexA = indexA.toType(at::kInt);
// indexB = indexB.toType(at::kInt);
// // Convert A to CSR format.
// auto row_ptrA = at::empty(m + 1, indexA.options());
// cusparseXcoo2csr(cusparse_handle, indexA[0].DATA_PTR<int>(), nnzA, k,
// row_ptrA.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto colA = indexA[1];
// cudaMemcpy(row_ptrA.DATA_PTR<int>() + m, &nnzA, sizeof(int),
// cudaMemcpyHostToDevice);
// // Convert B to CSR format.
// auto row_ptrB = at::empty(k + 1, indexB.options());
// cusparseXcoo2csr(cusparse_handle, indexB[0].DATA_PTR<int>(), nnzB, k,
// row_ptrB.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto colB = indexB[1];
// cudaMemcpy(row_ptrB.DATA_PTR<int>() + k, &nnzB, sizeof(int),
// cudaMemcpyHostToDevice);
// cusparseMatDescr_t descr = 0;
// cusparseCreateMatDescr(&descr);
// cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
// cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
// int nnzC;
// auto row_ptrC = at::empty(m + 1, indexB.options());
// cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr,
// nnzA, row_ptrA.DATA_PTR<int>(),
// colA.DATA_PTR<int>(), descr, nnzB,
// row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(),
// descr, row_ptrC.DATA_PTR<int>(), &nnzC);
// auto colC = at::empty(nnzC, indexA.options());
// auto valueC = at::empty(nnzC, valueA.options());
// CSRGEMM(valueC.scalar_type(), cusparse_handle,
// CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
// valueA.DATA_PTR<scalar_t>(), row_ptrA.DATA_PTR<int>(),
// colA.DATA_PTR<int>(), descr, nnzB, valueB.DATA_PTR<scalar_t>(),
// row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(), descr,
// valueC.DATA_PTR<scalar_t>(), row_ptrC.DATA_PTR<int>(),
// colC.DATA_PTR<int>());
// auto rowC = at::empty(nnzC, indexA.options());
// cusparseXcsr2coo(cusparse_handle, row_ptrC.DATA_PTR<int>(), nnzC, m,
// rowC.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);
// return std::make_tuple(indexC, valueC);
// }
// at::Tensor degree(at::Tensor row, int64_t num_nodes) {
// auto zero = at::zeros(num_nodes, row.options());
// auto one = at::ones(row.size(0), row.options());
// return zero.scatter_add_(0, row, one);
// }
// std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
// int64_t num_nodes) {
// // Assert already coalesced input.
// row = degree(row, num_nodes).cumsum(0);
// row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
// return std::make_tuple(row, col);
// }
// template <typename scalar_t>
// __global__ void spspmm_bw_kernel(
// const int64_t *__restrict__ index, scalar_t *__restrict__ value,
// const int64_t *__restrict__ rowA, const int64_t *__restrict__ colA,
// const scalar_t *__restrict__ valueA, const int64_t *__restrict__
// rowB, const int64_t *__restrict__ colB, const scalar_t *__restrict__
// valueB, const size_t numel) {
// const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
// const size_t stride = blockDim.x * gridDim.x;
// for (ptrdiff_t e = idx; e < numel; e += stride) {
// int64_t i = index[e], j = index[numel + e];
// for (ptrdiff_t dA = rowA[i]; dA < rowA[i + 1]; dA++) {
// int64_t cA = colA[dA];
// for (ptrdiff_t dB = rowB[j]; dB < rowB[j + 1]; dB++) {
// int64_t cB = colB[dB];
// if (cA == cB) {
// value[e] += valueA[dA] * valueB[dB];
// }
// if (cB >= cA) {
// break;
// }
// }
// }
// }
// }
// at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
// at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t rowA_max, size_t
// rowB_max) {
// cudaSetDevice(index.get_device());
// auto value = at::zeros(index.size(1), valueA.options());
// at::Tensor rowA, colA;
// std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
// at::Tensor rowB, colB;
// std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
// AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
// spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
// index.DATA_PTR<int64_t>(), value.DATA_PTR<scalar_t>(),
// rowA.DATA_PTR<int64_t>(), colA.DATA_PTR<int64_t>(),
// valueA.DATA_PTR<scalar_t>(), rowB.DATA_PTR<int64_t>(),
// colB.DATA_PTR<int64_t>(), valueB.DATA_PTR<scalar_t>(),
// value.numel());
// });
// return value;
// }
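For reference, the CPU fallback added to matmul.py below routes the same computation through scipy.sparse. A small worked example of that CSR path, using the 3x3 identity matrix from the new test (its CSR form is rowptr = [0, 1, 2, 3], col = [0, 1, 2], value = [1, 1, 1]); the numbers are purely illustrative:

```python
import numpy as np
import scipy.sparse

# CSR triple of the 3x3 identity matrix used in test_spspmm below.
rowptr = np.array([0, 1, 2, 3])
col = np.array([0, 1, 2])
value = np.array([1.0, 1.0, 1.0])

A = scipy.sparse.csr_matrix((value, col, rowptr), shape=(3, 3))
C = A @ A  # identity @ identity is again the identity

assert np.array_equal(C.indptr, rowptr)   # unchanged row pointer
assert np.array_equal(C.indices, col)     # unchanged column indices
assert np.allclose(C.data, value)         # unchanged values
```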
@@ -48,3 +48,19 @@ def test_spmm(dtype, device, reduce):
    assert torch.allclose(expected, out)
    assert torch.allclose(expected_grad_value, value.grad)
    assert torch.allclose(expected_grad_other, other.grad)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
    src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
                       device=device)
    src = SparseTensor.from_dense(src)

    out = src @ src
    assert out.size() == (3, 3)
    assert out.has_value()

    src.set_value_(None)
    out = src @ src
    assert out.size() == (3, 3)
    assert not out.has_value()
import torch
import scipy.sparse
from torch_sparse import spmm_cpu
from torch_scatter import scatter_add
@@ -8,6 +8,11 @@ try:
except ImportError:
    spmm_cuda = None

try:
    from torch_sparse import spspmm_cuda
except ImportError:
    spspmm_cuda = None


def spmm(is_cuda):
    return spmm_cuda if is_cuda else spmm_cpu
@@ -61,10 +66,9 @@ class SPMM(torch.autograd.Function):
        grad_mat = None
        if ctx.needs_input_grad[6]:
            if ctx.reduce in ['sum', 'add']:
                row = index[0][csr2csc]
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = spmm(grad_out.is_cuda).spmm(
                    colptr, row, value, grad_out, 'sum')
                    colptr, index[0][csr2csc], value, grad_out, 'sum')
            elif ctx.reduce == 'mean':
                count = rowcount[index[0]].to(mat.dtype).clamp_(min=1)
@@ -88,9 +92,61 @@
        return None, None, None, None, None, grad_value, grad_mat, None


class SPSPMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
        if rowptrA.is_cuda:
            indexC, rowptrC, valueC = spspmm_cuda.spspmm(
                rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K)
        else:
            dtype = None
            if valueA is not None:
                dtype = valueA.dtype
            if valueB is not None:
                dtype = valueB.dtype

            if valueA is None:
                valueA = torch.ones(colA.numel(), dtype=dtype)
            A = scipy.sparse.csr_matrix((valueA, colA, rowptrA), (M, N))

            if valueB is None:
                valueB = torch.ones(colB.numel(), dtype=dtype)
            B = scipy.sparse.csr_matrix((valueB, colB, rowptrB), (N, K))

            C = A @ B

            valueC = torch.from_numpy(
                C.data).to(dtype) if dtype is not None else None
            rowptrC = torch.from_numpy(C.indptr).to(torch.int64)

            C = C.tocoo()
            rowC = torch.from_numpy(C.row).to(torch.int64)
            colC = torch.from_numpy(C.col).to(torch.int64)
            indexC = torch.stack([rowC, colC], dim=0)

        # We cannot return `NoneType` in torch.autograd :(
        if valueC is None:
            return indexC, rowptrC
        else:
            return indexC, rowptrC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_rowptrC, *args):
        grad_valueA = None
        if ctx.needs_input_grad[2]:
            raise NotImplementedError

        grad_valueB = None
        if ctx.needs_input_grad[5]:
            raise NotImplementedError

        return (None, None, grad_valueA, None, None, grad_valueB, None, None,
                None)


def matmul(src, other, reduce='sum'):
    assert src.dim() == 2 and src.size(-1) == other.size(-2)

    # Sparse-Dense Matrix Multiplication.
    if torch.is_tensor(other):
        assert reduce in ['sum', 'add', 'mean', 'min', 'max']
        (index, value), rowptr = src.coo(), src.storage.rowptr
@@ -106,8 +162,16 @@ def matmul(src, other, reduce='sum'):
        return SPMM.apply(index, rowcount, rowptr, colptr, csr2csc, value,
                          other, reduce)

    # Sparse-Sparse Matrix Multiplication.
    elif isinstance(other, src.__class__):
        assert reduce in ['sum', 'add']
        raise NotImplementedError
        assert src.dim() == 2 and other.dim() == 2
        data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
                            other.size(1))
        data = data if len(data) == 3 else data + (None, )
        sparse_size = torch.Size([src.size(0), other.size(1)])
        out = src.__class__(data[0], data[2], sparse_size, is_sorted=True)
        out.storage._rowptr = data[1]
        return out

    raise ValueError
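For reference, a usage sketch of the new sparse-sparse path, mirroring test_spspmm above; it assumes SparseTensor exposes from_dense, set_value_ and has_value exactly as used in that test:

```python
import torch
from torch_sparse import SparseTensor

src = SparseTensor.from_dense(torch.eye(3))

out = src @ src                  # sparse @ sparse dispatches to SPSPMM.apply
assert out.size() == (3, 3) and out.has_value()

src.set_value_(None)             # keep only the sparsity pattern
out = src @ src
assert out.size() == (3, 3) and not out.has_value()
```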