Commit c4484dbb authored by rusty1s

jit support

parent 6e87043a
......@@ -15,7 +15,7 @@ torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
auto row_data = row.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
auto mask = torch::zeros(E + num_diag, row.options().dtype(at::kBool));
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask_data = mask.DATA_PTR<bool>();
int64_t r, c;
......
#include <torch/extension.h>
#include <torch/script.h>
#include "compat.h"
......@@ -85,9 +85,10 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
};
std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
at::Tensor mat, std::string reduce) {
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
std::string reduce) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
......@@ -105,12 +106,12 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
auto sizes = mat.sizes().vec();
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = at::empty(sizes, mat.options());
auto out = torch::empty(sizes, mat.options());
at::optional<at::Tensor> arg_out = at::nullopt;
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, col.numel(), rowptr.options());
arg_out = torch::full_like(out, col.numel(), rowptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
......@@ -174,8 +175,9 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
return std::make_tuple(out, arg_out);
}
at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
at::Tensor mat, at::Tensor grad, std::string reduce) {
torch::Tensor spmm_val_bw(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
CHECK_CPU(row);
CHECK_CPU(rowptr);
CHECK_CPU(col);
......@@ -191,7 +193,7 @@ at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto out = at::zeros(row.numel(), grad.options());
auto out = torch::zeros(row.numel(), grad.options());
auto row_data = row.DATA_PTR<int64_t>();
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
......@@ -224,8 +226,5 @@ at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spmm", &spmm, "Sparse-Dense Matrix Multiplication (CPU)");
m.def("spmm_val_bw", &spmm_val_bw,
"Sparse-Dense Matrix Multiplication Value Backward (CPU)");
}
static auto registry = torch::RegisterOperators("torch_sparse_cpu::spmm", &spmm)
.op("torch_sparse_cpu::spmm_val_bw", &spmm_val_bw);
......@@ -46,7 +46,7 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
auto row_data = row.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
auto mask = torch::zeros(E + num_diag, row.options().dtype(at::kBool));
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask_data = mask.DATA_PTR<bool>();
auto stream = at::cuda::getCurrentCUDAStream();
......
#include <torch/extension.h>
#include <torch/script.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
at::Tensor mat, std::string reduce);
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
std::string reduce);
at::Tensor spmm_val_bw_cuda(at::Tensor row, at::Tensor rowptr, at::Tensor col,
at::Tensor mat, at::Tensor grad,
std::string reduce);
torch::Tensor spmm_val_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce);
std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
at::Tensor mat, std::string reduce) {
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
std::string reduce) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
if (value_opt.has_value())
......@@ -21,8 +23,9 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
return spmm_cuda(rowptr, col, value_opt, mat, reduce);
}
at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
at::Tensor mat, at::Tensor grad, std::string reduce) {
torch::Tensor spmm_val_bw(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
CHECK_CUDA(row);
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
......@@ -31,8 +34,6 @@ at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
return spmm_val_bw_cuda(row, rowptr, col, mat, grad, reduce);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
m.def("spmm_val_bw", &spmm_val_bw,
"Sparse-Dense Matrix Multiplication Value Backward (CPU)");
}
static auto registry =
torch::RegisterOperators("torch_sparse_cuda::spmm", &spmm)
.op("torch_sparse_cuda::spmm_val_bw", &spmm_val_bw);
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "compat.cuh"
......@@ -155,9 +155,10 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
}
}
std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
at::Tensor mat, std::string reduce) {
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
std::string reduce) {
AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
AT_ASSERTM(col.dim() == 1, "Input mismatch");
......@@ -169,12 +170,12 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
auto sizes = mat.sizes().vec();
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = at::empty(sizes, mat.options());
auto out = torch::empty(sizes, mat.options());
at::optional<at::Tensor> arg_out = at::nullopt;
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, col.numel(), rowptr.options());
arg_out = torch::full_like(out, col.numel(), rowptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
......@@ -247,9 +248,9 @@ spmm_val_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
}
}
at::Tensor spmm_val_bw_cuda(at::Tensor row, at::Tensor rowptr, at::Tensor col,
at::Tensor mat, at::Tensor grad,
std::string reduce) {
torch::Tensor spmm_val_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
mat = mat.contiguous();
grad = grad.contiguous();
......@@ -261,7 +262,7 @@ at::Tensor spmm_val_bw_cuda(at::Tensor row, at::Tensor rowptr, at::Tensor col,
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
auto out = at::zeros(row.numel(), grad.options());
auto out = torch::zeros(row.numel(), grad.options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
......
#include <torch/extension.h>
#include <torch/script.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
at::optional<at::Tensor> valueA, at::Tensor rowptrB,
at::Tensor colB, at::optional<at::Tensor> valueB, int M, int N,
int K);
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB,
torch::Tensor colB, torch::optional<torch::Tensor> valueB,
int64_t M, int64_t N, int64_t K);
std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
spspmm(at::Tensor rowptrA, at::Tensor colA, at::optional<at::Tensor> valueA,
at::Tensor rowptrB, at::Tensor colB, at::optional<at::Tensor> valueB,
int M, int N, int K) {
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB,
torch::Tensor colB, torch::optional<torch::Tensor> valueB, int64_t M,
int64_t N, int64_t K) {
CHECK_CUDA(rowptrA);
CHECK_CUDA(colA);
if (valueA.has_value())
......@@ -23,40 +24,5 @@ spspmm(at::Tensor rowptrA, at::Tensor colA, at::optional<at::Tensor> valueA,
return spspmm_cuda(rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K);
}
// std::tuple<at::Tensor, at::Tensor>
// spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t m, size_t k, size_t n);
// at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
// at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t rowA_max, size_t
// rowB_max);
// std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor
// valueA,
// at::Tensor indexB, at::Tensor
// valueB, size_t m, size_t k, size_t
// n) {
// CHECK_CUDA(indexA);
// CHECK_CUDA(valueA);
// CHECK_CUDA(indexB);
// CHECK_CUDA(valueB);
// return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
// }
// at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
// at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
// size_t rowB_max) {
// CHECK_CUDA(index);
// CHECK_CUDA(indexA);
// CHECK_CUDA(valueA);
// CHECK_CUDA(indexB);
// CHECK_CUDA(valueB);
// return spspmm_bw_cuda(index, indexA, valueA, indexB, valueB, rowA_max,
// rowB_max);
// }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
// m.def("spspmm_bw", &spspmm_bw,
// "Sparse-Sparse Matrix Multiplication Backward (CUDA)");
}
static auto registry =
torch::RegisterOperators("torch_sparse_cuda::spspmm", &spspmm);
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cusparse.h>
......@@ -8,13 +8,13 @@
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case at::ScalarType::Float: { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseScsrgemm2_bufferSizeExt; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Double: { \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseDcsrgemm2_bufferSizeExt; \
......@@ -28,12 +28,12 @@
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case at::ScalarType::Float: { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2 = cusparseScsrgemm2; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Double: { \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2 = cusparseDcsrgemm2; \
return __VA_ARGS__(); \
......@@ -43,17 +43,17 @@
} \
}()
std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
at::optional<at::Tensor> valueA, at::Tensor rowptrB,
at::Tensor colB, at::optional<at::Tensor> valueB, int M, int N,
int K) {
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB,
torch::Tensor colB, torch::optional<torch::Tensor> valueB,
int64_t M, int64_t N, int64_t K) {
cusparseMatDescr_t descr = 0;
cusparseCreateMatDescr(&descr);
auto handle = at::cuda::getCurrentCUDASparseHandle();
rowptrA = rowptrA.toType(at::kInt), colA = colA.toType(at::kInt);
rowptrB = rowptrB.toType(at::kInt), colB = colB.toType(at::kInt);
rowptrA = rowptrA.toType(torch::kInt), colA = colA.toType(torch::kInt);
rowptrB = rowptrB.toType(torch::kInt), colB = colB.toType(torch::kInt);
auto rowptrA_data = rowptrA.DATA_PTR<int>(), colA_data = colA.DATA_PTR<int>();
auto rowptrB_data = rowptrB.DATA_PTR<int>(), colB_data = colB.DATA_PTR<int>();
......@@ -61,7 +61,7 @@ spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
csrgemm2Info_t info = NULL;
cusparseCreateCsrgemm2Info(&info);
auto scalar_type = at::ScalarType::Float;
auto scalar_type = torch::ScalarType::Float;
if (valueA.has_value())
scalar_type = valueA.value().scalar_type();
if (valueB.has_value())
......@@ -80,25 +80,25 @@ spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
cudaMalloc(&buffer, bufferSize);
int nnzC;
auto rowptrC = at::empty(M + 1, rowptrA.options());
auto rowptrC = torch::empty(M + 1, rowptrA.options());
auto rowptrC_data = rowptrC.DATA_PTR<int>();
cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
colA_data, descr, colB.numel(), rowptrB_data, colB_data,
descr, 0, NULL, NULL, descr, rowptrC_data, &nnzC, info,
buffer);
auto colC = at::empty(nnzC, colA.options());
auto colC = torch::empty(nnzC, colA.options());
auto colC_data = colC.DATA_PTR<int>();
if (!valueA.has_value() && valueB.has_value())
valueA = at::ones_like(valueB.value());
valueA = torch::ones_like(valueB.value());
if (!valueB.has_value() && valueA.has_value())
valueB = at::ones_like(valueA.value());
valueB = torch::ones_like(valueA.value());
at::optional<at::Tensor> valueC = at::nullopt;
torch::optional<torch::Tensor> valueC = torch::nullopt;
if (valueA.has_value())
valueC = at::empty(nnzC, valueA.value().options());
valueC = torch::empty(nnzC, valueA.value().options());
AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(scalar_type, [&] {
scalar_t alpha = (scalar_t)1;
......@@ -121,174 +121,8 @@ spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
descr, valueC_data, rowptrC_data, colC_data, info, buffer);
});
rowptrC = rowptrC.toType(at::kLong);
colC = colC.toType(at::kLong);
rowptrC = rowptrC.toType(torch::kLong);
colC = colC.toType(torch::kLong);
return std::make_tuple(rowptrC, colC, valueC);
}
// #define THREADS 1024
// #define BLOCKS(N) (N + THREADS - 1) / THREADS
// #define CSRGEMM(TYPE, ...) \
// [&] { \
// const auto &the_type = TYPE; \
// (void)the_type; \
// at::ScalarType _st = ::detail::scalar_type(TYPE); \
// switch (_st) { \
// case at::ScalarType::Float: { \
// using scalar_t = float; \
// return cusparseScsrgemm(__VA_ARGS__); \
// } \
// case at::ScalarType::Double: { \
// using scalar_t = double; \
// return cusparseDcsrgemm(__VA_ARGS__); \
// } \
// default: \
// AT_ERROR("Not implemented for '", toString(_st), "'"); \
// } \
// }()
// static cusparseHandle_t cusparse_handle = 0;
// static void init_cusparse() {
// if (cusparse_handle == 0) {
// cusparseStatus_t status = cusparseCreate(&cusparse_handle);
// }
// }
// std::tuple<at::Tensor, at::Tensor>
// spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t m, size_t k, size_t n) {
// cudaSetDevice(indexA.get_device());
// init_cusparse();
// indexA = indexA.contiguous();
// valueA = valueA.contiguous();
// indexB = indexB.contiguous();
// valueB = valueB.contiguous();
// auto nnzA = valueA.size(0);
// auto nnzB = valueB.size(0);
// indexA = indexA.toType(at::kInt);
// indexB = indexB.toType(at::kInt);
// // Convert A to CSR format.
// auto row_ptrA = at::empty(m + 1, indexA.options());
// cusparseXcoo2csr(cusparse_handle, indexA[0].DATA_PTR<int>(), nnzA, k,
// row_ptrA.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto colA = indexA[1];
// cudaMemcpy(row_ptrA.DATA_PTR<int>() + m, &nnzA, sizeof(int),
// cudaMemcpyHostToDevice);
// // Convert B to CSR format.
// auto row_ptrB = at::empty(k + 1, indexB.options());
// cusparseXcoo2csr(cusparse_handle, indexB[0].DATA_PTR<int>(), nnzB, k,
// row_ptrB.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto colB = indexB[1];
// cudaMemcpy(row_ptrB.DATA_PTR<int>() + k, &nnzB, sizeof(int),
// cudaMemcpyHostToDevice);
// cusparseMatDescr_t descr = 0;
// cusparseCreateMatDescr(&descr);
// cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
// cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
// int nnzC;
// auto row_ptrC = at::empty(m + 1, indexB.options());
// cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr,
// nnzA, row_ptrA.DATA_PTR<int>(),
// colA.DATA_PTR<int>(), descr, nnzB,
// row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(),
// descr, row_ptrC.DATA_PTR<int>(), &nnzC);
// auto colC = at::empty(nnzC, indexA.options());
// auto valueC = at::empty(nnzC, valueA.options());
// CSRGEMM(valueC.scalar_type(), cusparse_handle,
// CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
// valueA.DATA_PTR<scalar_t>(), row_ptrA.DATA_PTR<int>(),
// colA.DATA_PTR<int>(), descr, nnzB, valueB.DATA_PTR<scalar_t>(),
// row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(), descr,
// valueC.DATA_PTR<scalar_t>(), row_ptrC.DATA_PTR<int>(),
// colC.DATA_PTR<int>());
// auto rowC = at::empty(nnzC, indexA.options());
// cusparseXcsr2coo(cusparse_handle, row_ptrC.DATA_PTR<int>(), nnzC, m,
// rowC.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);
// return std::make_tuple(indexC, valueC);
// }
// at::Tensor degree(at::Tensor row, int64_t num_nodes) {
// auto zero = at::zeros(num_nodes, row.options());
// auto one = at::ones(row.size(0), row.options());
// return zero.scatter_add_(0, row, one);
// }
// std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
// int64_t num_nodes) {
// // Assert already coalesced input.
// row = degree(row, num_nodes).cumsum(0);
// row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
// return std::make_tuple(row, col);
// }
// template <typename scalar_t>
// __global__ void spspmm_bw_kernel(
// const int64_t *__restrict__ index, scalar_t *__restrict__ value,
// const int64_t *__restrict__ rowA, const int64_t *__restrict__ colA,
// const scalar_t *__restrict__ valueA, const int64_t *__restrict__
// rowB, const int64_t *__restrict__ colB, const scalar_t *__restrict__
// valueB, const size_t numel) {
// const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
// const size_t stride = blockDim.x * gridDim.x;
// for (ptrdiff_t e = idx; e < numel; e += stride) {
// int64_t i = index[e], j = index[numel + e];
// for (ptrdiff_t dA = rowA[i]; dA < rowA[i + 1]; dA++) {
// int64_t cA = colA[dA];
// for (ptrdiff_t dB = rowB[j]; dB < rowB[j + 1]; dB++) {
// int64_t cB = colB[dB];
// if (cA == cB) {
// value[e] += valueA[dA] * valueB[dB];
// }
// if (cB >= cA) {
// break;
// }
// }
// }
// }
// }
// at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
// at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t rowA_max, size_t
// rowB_max) {
// cudaSetDevice(index.get_device());
// auto value = at::zeros(index.size(1), valueA.options());
// at::Tensor rowA, colA;
// std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
// at::Tensor rowB, colB;
// std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
// AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
// spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
// index.DATA_PTR<int64_t>(), value.DATA_PTR<scalar_t>(),
// rowA.DATA_PTR<int64_t>(), colA.DATA_PTR<int64_t>(),
// valueA.DATA_PTR<scalar_t>(), rowB.DATA_PTR<int64_t>(),
// colB.DATA_PTR<int64_t>(), valueB.DATA_PTR<scalar_t>(),
// value.numel());
// });
// return value;
// }
import torch
from torch_scatter import gather_csr
def is_scalar(other):
return isinstance(other, int) or isinstance(other, float)
from torch_sparse.utils import is_scalar
def sparse_add(matA, matB):
......
import torch
from .utils import ext
from torch_sparse.utils import ext
def remove_diag(src, k=0):
......
import torch
import scipy.sparse
from torch_sparse import spmm_cpu
from torch_scatter import scatter_add
try:
from torch_sparse import spmm_cuda
except ImportError:
spmm_cuda = None
try:
from torch_sparse import spspmm_cuda
except ImportError:
spspmm_cuda = None
def spmm(is_cuda):
return spmm_cuda if is_cuda else spmm_cpu
from torch_sparse.utils import ext
class SPMM(torch.autograd.Function):
@staticmethod
def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
reduce):
out, arg_out = spmm(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
out, arg_out = ext(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
ctx.reduce = reduce
ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
......@@ -48,7 +34,7 @@ class SPMM(torch.autograd.Function):
grad_value = None
if ctx.needs_input_grad[3]:
if ctx.reduce in ['sum', 'add', 'mean']:
grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
grad_value = ext(grad_out.is_cuda).spmm_val_bw(
row, rowptr, col, mat, grad_out, ctx.reduce)
elif ctx.reduce in ['min', 'max']:
......@@ -63,7 +49,7 @@ class SPMM(torch.autograd.Function):
if ctx.needs_input_grad[4]:
if ctx.reduce in ['sum', 'add']:
value = value[csr2csc] if value is not None else value
grad_mat, _ = spmm(grad_out.is_cuda).spmm(
grad_mat, _ = ext(grad_out.is_cuda).spmm(
colptr, row[csr2csc], value, grad_out, 'sum')
elif ctx.reduce == 'mean':
......@@ -71,7 +57,7 @@ class SPMM(torch.autograd.Function):
value = count.pow_(-1) if value is None else value / count
row = row[csr2csc]
value = value[csr2csc] if value is not None else value
grad_mat, _ = spmm(grad_out.is_cuda).spmm(
grad_mat, _ = ext(grad_out.is_cuda).spmm(
colptr, row, value, grad_out, 'sum')
elif ctx.reduce in ['min', 'max']:
......@@ -92,9 +78,9 @@ class SPSPMM(torch.autograd.Function):
@staticmethod
def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
if rowptrA.is_cuda:
rowptrC, colC, valueC = spspmm_cuda.spspmm(rowptrA, colA, valueA,
rowptrB, colB, valueB,
M, N, K)
rowptrC, colC, valueC = ext(True).spspmm(rowptrA, colA, valueA,
rowptrB, colB, valueB, M,
N, K)
else:
dtype = None
if valueA is not None:
......@@ -149,7 +135,7 @@ def matmul(src, other, reduce='sum'):
row = None
if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
or other.reuqires_grad):
or other.requires_grad):
row = src.storage.row
rowcount = None
......
import torch
from torch_scatter import gather_csr
def is_scalar(other):
return isinstance(other, int) or isinstance(other, float)
from torch_sparse.utils import is_scalar
def mul(src, other):
......
......@@ -2,7 +2,7 @@ import warnings
import torch
from torch_scatter import segment_csr, scatter_add
from .utils import ext
from torch_sparse.utils import ext
__cache__ = {'enabled': True}
......
......@@ -15,6 +15,7 @@ from torch_sparse.diag import remove_diag, set_diag
from torch_sparse.matmul import matmul
from torch_sparse.add import add, add_, add_nnz, add_nnz_
from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
from torch_sparse.utils import is_scalar
class SparseTensor(object):
......@@ -501,10 +502,14 @@ TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR < 1) or (TORCH_MAJOR == 1 and TORCH_MINOR < 4):
def add(self, other):
return self.add(other) if torch.is_tensor(other) else NotImplemented
if torch.is_tensor(other) or is_scalar(other):
return self.add(other)
return NotImplemented
def mul(self, other):
return self.mul(other) if torch.is_tensor(other) else NotImplemented
if torch.is_tensor(other) or is_scalar(other):
return self.mul(other)
return NotImplemented
torch.Tensor.__add__ = add
torch.Tensor.__mul__ = add
torch.Tensor.__mul__ = mul
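The widened check matters because the old patch returned NotImplemented for plain tensor + scalar, and __mul__ was accidentally bound to add. Returning NotImplemented for anything that is neither a tensor nor a scalar lets Python fall back to the right-hand operand's reflected method, which is how an expression like dense + SparseTensor can still be routed to the sparse side (assuming SparseTensor implements the reflected operators). A toy illustration of that protocol, with invented class names:

class Dense:
    def __add__(self, other):
        if isinstance(other, (int, float, Dense)):
            return 'handled by Dense.__add__'
        return NotImplemented  # defer to other.__radd__

class Sparse:
    def __radd__(self, other):
        return 'handled by Sparse.__radd__'

print(Dense() + 2)         # handled by Dense.__add__
print(Dense() + Sparse())  # falls through to Sparse.__radd__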
......@@ -2,10 +2,13 @@ import torch
torch.ops.load_library('torch_sparse/convert_cpu.so')
torch.ops.load_library('torch_sparse/diag_cpu.so')
torch.ops.load_library('torch_sparse/spmm_cpu.so')
try:
torch.ops.load_library('torch_sparse/convert_cuda.so')
torch.ops.load_library('torch_sparse/diag_cuda.so')
torch.ops.load_library('torch_sparse/spmm_cuda.so')
torch.ops.load_library('torch_sparse/spspmm_cuda.so')
except OSError as e:
if torch.cuda.is_available():
raise e
......@@ -14,3 +17,7 @@ except OSError as e:
def ext(is_cuda):
name = 'torch_sparse_cuda' if is_cuda else 'torch_sparse_cpu'
return getattr(torch.ops, name)
def is_scalar(other):
return isinstance(other, int) or isinstance(other, float)
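ext() is what lets the autograd functions above stay device-agnostic: they ask for torch.ops.torch_sparse_cpu or torch.ops.torch_sparse_cuda based on the input tensor, and the try/except around the CUDA load_library calls keeps CPU-only installs working. A small sketch of the call path, assuming a built install of this commit and invented tensors:

import torch
from torch_sparse.utils import ext  # importing this module runs the load_library calls

rowptr = torch.tensor([0, 2, 3])    # 2-row CSR pattern
col = torch.tensor([0, 1, 1])
mat = torch.randn(2, 4)

ops = ext(mat.is_cuda)              # torch.ops.torch_sparse_cpu for a CPU tensor
out, arg_out = ops.spmm(rowptr, col, None, mat, 'max')
# For 'max' the second output holds argmax indices; rows without entries keep
# the col.numel() fill value that arg_out is initialised with in the kernels.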