Commit ee8c16a7 authored by rusty1s

matmul

parent c5f5be51
#include "diag_cpu.h"
#include "utils.h"
torch::Tensor non_diag_mask_cpu(torch::Tensor row, torch::Tensor col, int64_t M,
                                int64_t N, int64_t k) {
  CHECK_CPU(row);
  CHECK_CPU(col);
  auto E = row.size(0);
  auto num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
  auto row_data = row.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
  auto mask_data = mask.data_ptr<bool>();
  int64_t r, c;
  if (k < 0) {
...@@ -47,6 +45,3 @@ torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
  return mask;
}
static auto registry =
torch::RegisterOperators("torch_sparse_cpu::non_diag_mask", &non_diag_mask);
#pragma once
#include <torch/extension.h>
torch::Tensor non_diag_mask_cpu(torch::Tensor row, torch::Tensor col, int64_t M,
int64_t N, int64_t k);
#pragma once
#include <limits>
#include <map>
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (count > 0 ? count : (scalar_t)1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
};
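The dispatch macro and the Reducer template above are meant to be combined as in the following minimal sketch. It is not part of the commit: it assumes this header is available as "reducer.h" on the include path, and reduce_row is a hypothetical helper that reduces one row of values with a reduction chosen by name at runtime.

#include <cstdint>
#include <string>
#include <vector>
#include "reducer.h"

// Reduce a row of floats; the runtime string is mapped to the constant REDUCE,
// so Reducer<float, REDUCE> is instantiated once per reduction type.
float reduce_row(const std::vector<float> &vals, const std::string &reduce) {
  return AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
    float acc = Reducer<float, REDUCE>::init();
    int64_t arg = -1;
    for (int64_t i = 0; i < (int64_t)vals.size(); i++)
      Reducer<float, REDUCE>::update(&acc, vals[i], &arg, i);
    float out;
    int64_t arg_out;
    Reducer<float, REDUCE>::write(&out, acc, &arg_out, arg, (int)vals.size());
    return out;
  });
}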
#include "spmm_cpu.h"
#include "reducer.h"
#include "utils.h"
#include "compat.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
#define AT_DISPATCH_HAS_VAL(value_opt, ...) \
[&] { \
switch (value_opt.has_value()) { \
case true: { \
const bool HAS_VAL = true; \
return __VA_ARGS__(); \
} \
case false: { \
const bool HAS_VAL = false; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
return std::numeric_limits<scalar_t>::lowest();
} else {
return (scalar_t)0;
}
}
  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
                            int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM) {
*address = val;
} else if (REDUCE == MEAN) {
*address = val / (count > 0 ? count : (scalar_t)1);
} else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else {
*address = (scalar_t)0;
}
}
}
};
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
         torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
         std::string reduce) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (optional_value.has_value())
    CHECK_CPU(optional_value.value());
  CHECK_CPU(mat);
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  if (optional_value.has_value()) {
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().size(0) == col.size(0));
  }
  CHECK_INPUT(mat.dim() >= 2);
  mat = mat.contiguous();
...@@ -112,11 +31,11 @@ spmm(torch::Tensor rowptr, torch::Tensor col,
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full_like(out, col.numel(), rowptr.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto M = rowptr.numel() - 1;
  auto N = mat.size(-2);
...@@ -125,8 +44,8 @@ spmm(torch::Tensor rowptr, torch::Tensor col,
  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
    scalar_t *value_data = nullptr;
    auto mat_data = mat.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
    scalar_t val;
    std::vector<scalar_t> vals(K);
...@@ -134,9 +53,9 @@ spmm(torch::Tensor rowptr, torch::Tensor col,
    std::vector<int64_t> args(K);
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      AT_DISPATCH_HAS_VALUE(optional_value, [&] {
        if (HAS_VALUE) {
          value_data = optional_value.value().data_ptr<scalar_t>();
        }
        for (int b = 0; b < B; b++) {
...@@ -149,10 +68,10 @@ spmm(torch::Tensor rowptr, torch::Tensor col,
          int offset = b * N * K;
          for (int e = row_start; e < row_end; e++) {
            c = col_data[e];
            if (HAS_VALUE)
              val = value_data[e];
            for (int k = 0; k < K; k++) {
              if (HAS_VALUE)
                Reducer<scalar_t, REDUCE>::update(
                    &vals[k], val * mat_data[offset + c * K + k], &args[k],
                    e);
...@@ -175,7 +94,7 @@ spmm(torch::Tensor rowptr, torch::Tensor col,
  return std::make_tuple(out, arg_out);
}
torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
                                torch::Tensor col, torch::Tensor mat,
                                torch::Tensor grad, std::string reduce) {
  CHECK_CPU(row);
...@@ -195,13 +114,13 @@ torch::Tensor spmm_val_bw(torch::Tensor row, torch::Tensor rowptr,
  auto out = torch::zeros(row.numel(), grad.options());
  auto row_data = row.data_ptr<int64_t>();
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_value_bw", [&] {
    auto mat_data = mat.data_ptr<scalar_t>();
    auto grad_data = grad.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
    scalar_t val;
    int64_t row, col;
...@@ -225,6 +144,3 @@ torch::Tensor spmm_val_bw(torch::Tensor row, torch::Tensor rowptr,
  return out;
}
static auto registry = torch::RegisterOperators("torch_sparse_cpu::spmm", &spmm)
.op("torch_sparse_cpu::spmm_val_bw", &spmm_val_bw);
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce);
torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce);
...@@ -4,3 +4,17 @@
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define AT_DISPATCH_HAS_VALUE(optional_value, ...) \
[&] { \
switch (optional_value.has_value()) { \
case true: { \
const bool HAS_VALUE = true; \
return __VA_ARGS__(); \
} \
case false: { \
const bool HAS_VALUE = false; \
return __VA_ARGS__(); \
} \
} \
}()
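AT_DISPATCH_HAS_VALUE applies the same pattern to the optional value tensor: the callable is pasted into both cases, so HAS_VALUE is a constant in each copy and the weighted and unweighted code paths can be written once. A small sketch, not part of the commit; row_weights is a hypothetical helper and "utils.h" refers to this header:

#include <torch/extension.h>
#include "utils.h"

// Return per-edge weights: the stored values if present, otherwise ones.
torch::Tensor row_weights(torch::optional<torch::Tensor> optional_value,
                          torch::Tensor col) {
  return AT_DISPATCH_HAS_VALUE(optional_value, [&] {
    if (HAS_VALUE)
      return optional_value.value();
    else
      return torch::ones_like(col, col.options().dtype(torch::kFloat));
  });
}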
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/script.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k);
torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
int64_t N, int64_t k) {
CHECK_CUDA(row);
CHECK_CUDA(col);
return non_diag_mask_cuda(row, col, M, N, k);
}
static auto registry = torch::RegisterOperators(
"torch_sparse_cuda::non_diag_mask", &non_diag_mask);
#include "diag_cuda.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "compat.cuh" #include "utils.cuh"
#define THREADS 1024 #define THREADS 1024
...@@ -40,16 +41,18 @@ __global__ void non_diag_mask_kernel(const int64_t *row_data, ...@@ -40,16 +41,18 @@ __global__ void non_diag_mask_kernel(const int64_t *row_data,
torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col, torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k) { int64_t M, int64_t N, int64_t k) {
CHECK_CUDA(row);
CHECK_CUDA(col);
cudaSetDevice(row.get_device()); cudaSetDevice(row.get_device());
int64_t E = row.size(0); auto E = row.size(0);
int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k); auto num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
auto row_data = row.DATA_PTR<int64_t>(); auto row_data = row.data_ptr<int64_t>();
auto col_data = col.DATA_PTR<int64_t>(); auto col_data = col.data_ptr<int64_t>();
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool)); auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask_data = mask.DATA_PTR<bool>(); auto mask_data = mask.data_ptr<bool>();
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>( non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
......
#pragma once
#include <torch/extension.h>
torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k);
#include "spmm_cuda.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {
return std::make_tuple(mat, optional_value);
}
torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
return row;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce);
torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce);
#include <torch/script.h>
#include "cpu/diag_cpu.h"
#ifdef WITH_CUDA
#include "cuda/diag_cuda.h"
#endif
torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
int64_t N, int64_t k) {
if (row.device().is_cuda()) {
#ifdef WITH_CUDA
return non_diag_mask_cuda(row, col, M, N, k);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return non_diag_mask_cpu(row, col, M, N, k);
}
}
static auto registry = torch::RegisterOperators().op(
"torch_sparse::non_diag_mask", &non_diag_mask);
#include <torch/script.h>
#include "cpu/spmm_cpu.h"
#ifdef WITH_CUDA
#include "cuda/spmm_cuda.h"
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_fw(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return spmm_cuda(rowptr, col, optional_value, mat, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spmm_cpu(rowptr, col, optional_value, mat, reduce);
}
}
torch::Tensor spmm_value_bw(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return spmm_value_bw_cuda(row, rowptr, col, mat, grad, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spmm_value_bw_cpu(row, rowptr, col, mat, grad, reduce);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SPMMSum : public torch::autograd::Function<SPMMSum> {
public:
static variable_list forward(AutogradContext *ctx,
torch::optional<Variable> optional_row,
Variable rowptr, Variable col, Variable value,
torch::optional<Variable> optional_colptr,
torch::optional<Variable> optional_csr2csc,
Variable mat) {
torch::Tensor row;
if (optional_row.has_value())
row = optional_row.value();
torch::optional<torch::Tensor> optional_value = torch::nullopt;
if (value.numel() > 0)
optional_value = value;
torch::Tensor colptr;
if (optional_colptr.has_value())
colptr = optional_colptr.value();
torch::Tensor csr2csc;
if (optional_csr2csc.has_value())
csr2csc = optional_csr2csc.value();
auto out = std::get<0>(spmm_fw(rowptr, col, optional_value, mat, "sum"));
ctx->save_for_backward({row, rowptr, col, value, colptr, csr2csc, mat});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto row = saved[0];
auto rowptr = saved[1];
auto col = saved[2];
auto value = saved[3];
torch::optional<torch::Tensor> optional_value = torch::nullopt;
if (value.numel() > 0)
optional_value = value;
auto colptr = saved[4];
auto csr2csc = saved[5];
auto mat = saved[6];
auto grad_value = Variable();
if (optional_value.has_value() &&
torch::autograd::any_variable_requires_grad({value})) {
grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "sum");
}
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
if (optional_value.has_value())
optional_value = optional_value.value().index_select(0, csr2csc);
grad_mat = torch::zeros_like(mat);
grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
optional_value, grad_out, "sum"));
}
return {Variable(), Variable(), Variable(), grad_value,
Variable(), Variable(), grad_mat};
}
};
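The backward pass above reuses the forward SpMM kernel on the CSC view of the matrix: colptr, row.index_select(0, csr2csc) and value.index_select(0, csr2csc) describe A^T in CSR form. The gradient identities it relies on, for C = A X with sparse A and dense X (a sketch of the math, not text from the commit):

\[
\frac{\partial L}{\partial X} = A^{\top}\,\frac{\partial L}{\partial C},
\qquad
\frac{\partial L}{\partial A_{ij}} = \sum_{k} \frac{\partial L}{\partial C_{ik}}\, X_{jk}
\quad\text{for } (i, j) \in \operatorname{nnz}(A).
\]

The second expression, evaluated once per stored entry, is what spmm_value_bw is used for here.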
torch::Tensor spmm_sum(torch::optional<torch::Tensor> optional_row,
torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_colptr,
torch::optional<torch::Tensor> optional_csr2csc,
torch::Tensor mat) {
// Since we cannot return an *optional* gradient, we need to convert
// `optional_value` to an empty sized tensor first :(
auto value = torch::Tensor();
if (optional_value.has_value())
value = optional_value.value();
return SPMMSum::apply(optional_row, rowptr, col, value, optional_colptr,
optional_csr2csc, mat)[0];
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::spmm_sum", &spmm_sum);
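A usage sketch, not part of the commit: it assumes the snippet is compiled together with the file above and multiplies a 2x3 CSR matrix with a dense 3x2 matrix through the autograd-aware spmm_sum. The CSC metadata (row, colptr, csr2csc) is spelled out by hand so that the backward pass has everything it needs.

#include <iostream>
#include <torch/torch.h>

int main() {
  // A = [[1, 0, 2], [0, 3, 0]] in CSR form.
  auto rowptr = torch::tensor({0, 2, 3}, torch::kLong);
  auto col = torch::tensor({0, 2, 1}, torch::kLong);
  auto value = torch::tensor({1.0f, 2.0f, 3.0f},
                             torch::dtype(torch::kFloat).requires_grad(true));
  // COO rows plus the CSC view of A, needed only by the backward pass.
  auto row = torch::tensor({0, 0, 1}, torch::kLong);
  auto colptr = torch::tensor({0, 1, 2, 3}, torch::kLong);
  auto csr2csc = torch::tensor({0, 2, 1}, torch::kLong);

  auto mat = torch::rand({3, 2}, torch::requires_grad());
  auto out = spmm_sum(row, rowptr, col, value, colptr, csr2csc, mat);  // 2x2

  out.sum().backward();
  std::cout << out << std::endl;
  std::cout << mat.grad() << std::endl;    // equals A.t() @ ones(2, 2)
  std::cout << value.grad() << std::endl;  // one gradient per stored entry of A
}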
...@@ -10,6 +10,9 @@ import torch_scatter
from .utils import devices, grad_dtypes

reductions = ['sum', 'mean', 'min', 'max']
devices = ['cpu']
grad_dtypes = [torch.float]
reductions = ['sum']
@pytest.mark.parametrize('dtype,device,reduce', @pytest.mark.parametrize('dtype,device,reduce',
...@@ -25,14 +28,14 @@ def test_spmm(dtype, device, reduce):
                        requires_grad=True)
    src_col = other.index_select(-2, col) * value.unsqueeze(-1)
    expected = torch_scatter.scatter(src_col, row, dim=-2, reduce=reduce)
    if reduce == 'min':
        expected[expected > 1000] = 0
    if reduce == 'max':
        expected[expected < -1000] = 0
    print(expected)

    grad_out = torch.randn_like(expected)
    expected.backward(grad_out)
...@@ -42,7 +45,6 @@ def test_spmm(dtype, device, reduce):
    other.grad = None

    out = matmul(src, other, reduce)
    out.backward(grad_out)

    assert torch.allclose(expected, out)
...@@ -50,17 +52,17 @@ def test_spmm(dtype, device, reduce):
    assert torch.allclose(expected_grad_other, other.grad)
# @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
# def test_spspmm(dtype, device):
#     src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
#                        device=device)
#     src = SparseTensor.from_dense(src)
#     out = src @ src
#     assert out.size() == (3, 3)
#     assert out.has_value()
#     src.set_value_(None)
#     out = src @ src
#     assert out.size() == (3, 3)
#     assert not out.has_value()
...@@ -46,3 +46,4 @@ from .diag import set_diag, remove_diag
from .add import add, add_, add_nnz, add_nnz_
from .mul import mul, mul_, mul_nnz, mul_nnz_
from .reduce import sum, mean, min, max
from .matmul import spmm_sum, spmm_add, spmm, matmul
import warnings
import os.path as osp
from typing import Optional

import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor

try:
    torch.ops.load_library(
        osp.join(osp.dirname(osp.abspath(__file__)), '_diag.so'))
except OSError:
    warnings.warn('Failed to load `diag` binaries.')

    def non_diag_mask_placeholder(row: torch.Tensor, col: torch.Tensor, M: int,
                                  N: int, k: int) -> torch.Tensor:
        raise ImportError
        return row

    torch.ops.torch_sparse.non_diag_mask = non_diag_mask_placeholder


@torch.jit.script
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
...@@ -38,13 +53,8 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
    src = remove_diag(src, k=0)
    row, col, value = src.coo()

    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
...
import warnings
import os.path as osp
from typing import Optional, Union
import torch
from torch_sparse.tensor import SparseTensor

try:
    torch.ops.load_library(
        osp.join(osp.dirname(osp.abspath(__file__)), '_spmm.so'))
except OSError:
    warnings.warn('Failed to load `spmm` binaries.')

    def spmm_sum_placeholder(row: Optional[torch.Tensor], rowptr: torch.Tensor,
                             col: torch.Tensor, value: Optional[torch.Tensor],
                             colptr: Optional[torch.Tensor],
                             csr2csc: Optional[torch.Tensor],
                             mat: torch.Tensor) -> torch.Tensor:
        raise ImportError
        return mat

    torch.ops.torch_sparse.spmm_sum = spmm_sum_placeholder


@torch.jit.script
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    rowptr, col, value = src.csr()

    row = src.storage._row
    csr2csc = src.storage._csr2csc
    colptr = src.storage._colptr

    if value is not None and value.requires_grad:
        row = src.storage.row()

    if other.requires_grad:
        row = src.storage.row()
        csr2csc = src.storage.csr2csc()
        colptr = src.storage.colptr()

    print(row is not None)
    print(csr2csc is not None)
    print(colptr is not None)

    return torch.ops.torch_sparse.spmm_sum(row, rowptr, col, value, colptr,
                                           csr2csc, other)


@torch.jit.script
def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    return spmm_sum(src, other)


@torch.jit.script
def spmm(src: SparseTensor, other: torch.Tensor,
         reduce: str = "sum") -> torch.Tensor:
    if reduce == 'sum' or reduce == 'add':
        return spmm_sum(src, other)
    else:
        raise ValueError


def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
           reduce: str = "sum"):
    if torch.is_tensor(other):
        return spmm(src, other, reduce)
    else:
        raise ValueError


SparseTensor.spmm = lambda self, other, reduce=None: spmm(self, other, reduce)
SparseTensor.matmul = lambda self, other, reduce=None: matmul(
    self, other, reduce)
SparseTensor.__matmul__ = lambda self, other: matmul(self, other, 'sum')
# class SPMM(torch.autograd.Function):
# @staticmethod
# def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
# reduce):
# if mat.is_cuda:
# out, arg_out = torch.ops.torch_sparse_cuda.spmm(
# rowptr, col, value, mat, reduce)
# else:
# out, arg_out = torch.ops.torch_sparse_cpu.spmm(
# rowptr, col, value, mat, reduce)
# ctx.reduce = reduce
# ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
# csr2csc, arg_out)
# if reduce == 'min' or reduce == 'max':
# ctx.mark_non_differentiable(arg_out)
# return out, arg_out
# else:
# return out
# @staticmethod
# def backward(ctx, grad_out, *args):
# (row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
# arg_out) = ctx.saved_tensors
# invalid_arg_mask = arg_out_ind = None
# if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
# or ctx.needs_input_grad[4]):
# invalid_arg_mask = arg_out == col.size(0)
# arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)
# grad_value = None
# if ctx.needs_input_grad[3]:
# if ctx.reduce in ['sum', 'add', 'mean']:
# grad_value = ext(grad_out.is_cuda).spmm_val_bw(
# row, rowptr, col, mat, grad_out, ctx.reduce)
# elif ctx.reduce in ['min', 'max']:
# col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
# out = mat.gather(-2, col_tmp).mul_(grad_out)
# out.masked_fill_(invalid_arg_mask, 0)
# grad_value = scatter_add(out.flatten(), arg_out.flatten(),
# dim=0, dim_size=value.numel() + 1)
# grad_value = grad_value[:-1]
# grad_mat = None
# if ctx.needs_input_grad[4]:
# if ctx.reduce in ['sum', 'add']:
# value = value[csr2csc] if value is not None else value
# grad_mat, _ = ext(grad_out.is_cuda).spmm(
# colptr, row[csr2csc], value, grad_out, 'sum')
# elif ctx.reduce == 'mean':
# count = rowcount[row].to(mat.dtype).clamp_(min=1)
# value = count.pow_(-1) if value is None else value / count
# row = row[csr2csc]
# value = value[csr2csc] if value is not None else value
# grad_mat, _ = ext(grad_out.is_cuda).spmm(
# colptr, row, value, grad_out, 'sum')
# elif ctx.reduce in ['min', 'max']:
# if value is not None:
# value = value[arg_out_ind.flatten()].view_as(arg_out)
# value = value.mul_(grad_out)
# else:
# value = grad_out
# value.masked_fill_(invalid_arg_mask, 0)
# col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
# grad_mat = scatter_add(value, col_tmp, dim=-2,
# dim_size=mat.size(-2))
# return None, None, None, grad_value, grad_mat, None, None, None, None
# class SPSPMM(torch.autograd.Function):
# @staticmethod
# def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
# if rowptrA.is_cuda:
# rowptrC, colC, valueC = ext(True).spspmm(rowptrA, colA, valueA,
# rowptrB, colB, valueB, M,
# N, K)
# else:
# dtype = None
# if valueA is not None:
# dtype = valueA.dtype
# if valueB is not None:
# dtype = valueB.dtype
# if valueA is None:
# valueA = torch.ones(colA.numel(), dtype=dtype)
# A = scipy.sparse.csr_matrix((valueA, colA, rowptrA), (M, N))
# if valueB is None:
# valueB = torch.ones(colB.numel(), dtype=dtype)
# B = scipy.sparse.csr_matrix((valueB, colB, rowptrB), (N, K))
# C = A @ B
# rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
# colC = torch.from_numpy(C.indices).to(torch.int64)
# valueC = torch.from_numpy(C.data)
# valueC = valueC.to(dtype) if dtype is not None else None
# ctx.mark_non_differentiable(rowptrC, colC)
# # We cannot return `NoneType` in torch.autograd :(
# if valueC is None:
# return rowptrC, colC
# else:
# return rowptrC, colC, valueC
# @staticmethod
# def backward(ctx, grad_indexC, grad_rowptrC, *args):
# grad_valueA = None
# if ctx.needs_input_grad[2]:
# raise NotImplementedError
# grad_valueB = None
# if ctx.needs_input_grad[5]:
# raise NotImplementedError
# return (None, None, grad_valueA, None, None, grad_valueB, None, None,
# None)
# def matmul(src, other, reduce='sum'):
# assert src.dim() == 2 and src.size(-1) == other.size(-2)
# # Sparse-Dense Matrix Multiplication.
# if torch.is_tensor(other):
# assert reduce in ['sum', 'add', 'mean', 'min', 'max']
# rowptr, col, value = src.csr()
# row = None
# if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
# or other.requires_grad):
# row = src.storage.row
# rowcount = None
# if other.requires_grad and reduce in ['mean']:
# rowcount = src.storage.rowcount
# csr2csc = colptr = None
# if other.requires_grad and reduce in ['sum', 'add', 'mean']:
# csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
# return SPMM.apply(row, rowptr, col, value, other, rowcount, colptr,
# csr2csc, reduce)
# # Sparse-Sparse Matrix Multiplication.
# elif isinstance(other, src.__class__):
# assert reduce in ['sum', 'add']
# assert src.dim() == 2 and other.dim() == 2
# data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
# other.size(1))
# (rowptr, col), value = data[:2], data[2] if len(data) == 3 else None
# sparse_size = torch.Size([src.size(0), other.size(1)])
# return src.__class__(rowptr=rowptr, col=col, value=value,
# sparse_size=sparse_size, is_sorted=True)
# raise ValueError