Commit e44a639f authored by rusty1s

spmm done

parent bb1ba6b0
@@ -168,6 +168,122 @@ public:
  }
};
class SPMMMin : public torch::autograd::Function<SPMMMin> {
public:
  static variable_list forward(AutogradContext *ctx, Variable rowptr,
                               Variable col, Variable value, Variable mat,
                               bool has_value) {
    torch::optional<torch::Tensor> opt_value = torch::nullopt;
    if (has_value)
      opt_value = value;

    auto result = spmm_fw(rowptr, col, opt_value, mat, "min");
    auto out = std::get<0>(result);
    auto arg_out = std::get<1>(result).value();
    ctx->saved_data["has_value"] = has_value;
    ctx->save_for_backward({col, value, mat, arg_out});
    ctx->mark_non_differentiable({arg_out});
    return {out, arg_out};
  }

  static variable_list backward(AutogradContext *ctx,
                                variable_list grad_outs) {
    auto has_value = ctx->saved_data["has_value"].toBool();
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto col = saved[0], value = saved[1], mat = saved[2], arg_out = saved[3];

    auto invalid_arg_mask = arg_out == col.size(0);
    arg_out = arg_out.masked_fill(invalid_arg_mask, 0);

    auto grad_value = Variable();
    if (has_value > 0 &&
        torch::autograd::any_variable_requires_grad({value})) {
      auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
      auto out = mat.gather(-2, ind);
      out.mul_(grad_out);
      out.masked_fill_(invalid_arg_mask, 0);
      grad_value = torch::zeros_like(value);
      grad_value.scatter_add_(0, arg_out.flatten(), out.flatten());
    }

    auto grad_mat = Variable();
    if (torch::autograd::any_variable_requires_grad({mat})) {
      if (has_value > 0) {
        value = value.index_select(0, arg_out.flatten()).view_as(arg_out);
        value.mul_(grad_out);
      } else
        value = grad_out;
      value.masked_fill_(invalid_arg_mask, 0);

      auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
      grad_mat = torch::zeros_like(mat);
      grad_mat.scatter_add_(-2, ind, value);
    }

    return {Variable(), Variable(), grad_value, grad_mat, Variable()};
  }
};
class SPMMMax : public torch::autograd::Function<SPMMMax> {
public:
  static variable_list forward(AutogradContext *ctx, Variable rowptr,
                               Variable col, Variable value, Variable mat,
                               bool has_value) {
    torch::optional<torch::Tensor> opt_value = torch::nullopt;
    if (has_value)
      opt_value = value;

    auto result = spmm_fw(rowptr, col, opt_value, mat, "max");
    auto out = std::get<0>(result);
    auto arg_out = std::get<1>(result).value();
    ctx->saved_data["has_value"] = has_value;
    ctx->save_for_backward({col, value, mat, arg_out});
    ctx->mark_non_differentiable({arg_out});
    return {out, arg_out};
  }

  static variable_list backward(AutogradContext *ctx,
                                variable_list grad_outs) {
    auto has_value = ctx->saved_data["has_value"].toBool();
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto col = saved[0], value = saved[1], mat = saved[2], arg_out = saved[3];

    auto invalid_arg_mask = arg_out == col.size(0);
    arg_out = arg_out.masked_fill(invalid_arg_mask, 0);

    auto grad_value = Variable();
    if (has_value > 0 &&
        torch::autograd::any_variable_requires_grad({value})) {
      auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
      auto out = mat.gather(-2, ind);
      out.mul_(grad_out);
      out.masked_fill_(invalid_arg_mask, 0);
      grad_value = torch::zeros_like(value);
      grad_value.scatter_add_(0, arg_out.flatten(), out.flatten());
    }

    auto grad_mat = Variable();
    if (torch::autograd::any_variable_requires_grad({mat})) {
      if (has_value > 0) {
        value = value.index_select(0, arg_out.flatten()).view_as(arg_out);
        value.mul_(grad_out);
      } else
        value = grad_out;
      value.masked_fill_(invalid_arg_mask, 0);

      auto ind = col.index_select(0, arg_out.flatten()).view_as(arg_out);
      grad_mat = torch::zeros_like(mat);
      grad_mat.scatter_add_(-2, ind, value);
    }

    return {Variable(), Variable(), grad_value, grad_mat, Variable()};
  }
};
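Note (not part of this commit): the backward passes of SPMMMin and SPMMMax above route gradients only through the entries recorded in arg_out, i.e. the non-zeros that won the reduction. A minimal Python sketch of the same logic, assuming mat has shape [N, K], grad_out and arg_out have shape [M, K], arg_out indexes into col/value (with col.numel() marking empty rows), and value is the 1-D tensor of non-zero values:

import torch

def spmm_min_max_backward_sketch(grad_out, col, value, mat, arg_out):
    # Empty rows store col.numel() in arg_out; mask their contributions out.
    invalid = arg_out == col.numel()
    arg = arg_out.masked_fill(invalid, 0)
    # Column index of the winning non-zero for every output entry (i, k).
    ind = col[arg.flatten()].view_as(arg)
    # d out[i, k] / d value[e] = mat[col[e], k] for winning entries, 0 elsewhere.
    g = (mat.gather(-2, ind) * grad_out).masked_fill(invalid, 0)
    grad_value = torch.zeros_like(value).scatter_add_(0, arg.flatten(), g.flatten())
    # d out[i, k] / d mat[col[e], k] = value[e] for winning entries, 0 elsewhere.
    v = (value[arg.flatten()].view_as(arg) * grad_out).masked_fill(invalid, 0)
    grad_mat = torch.zeros_like(mat).scatter_add_(-2, ind, v)
    return grad_value, grad_mat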
torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
                       torch::Tensor rowptr, torch::Tensor col,
                       torch::optional<torch::Tensor> opt_value,
@@ -191,6 +307,24 @@ torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
                          opt_csr2csc, mat, opt_value.has_value())[0];
}
std::tuple<torch::Tensor, torch::Tensor>
spmm_min(torch::Tensor rowptr, torch::Tensor col,
         torch::optional<torch::Tensor> opt_value, torch::Tensor mat) {
  auto value = opt_value.has_value() ? opt_value.value() : col;
  auto result = SPMMMin::apply(rowptr, col, value, mat, opt_value.has_value());
  return std::make_tuple(result[0], result[1]);
}

std::tuple<torch::Tensor, torch::Tensor>
spmm_max(torch::Tensor rowptr, torch::Tensor col,
         torch::optional<torch::Tensor> opt_value, torch::Tensor mat) {
  auto value = opt_value.has_value() ? opt_value.value() : col;
  auto result = SPMMMax::apply(rowptr, col, value, mat, opt_value.has_value());
  return std::make_tuple(result[0], result[1]);
}
static auto registry = torch::RegisterOperators()
                           .op("torch_sparse::spmm_sum", &spmm_sum)
                           .op("torch_sparse::spmm_mean", &spmm_mean)
                           .op("torch_sparse::spmm_min", &spmm_min)
                           .op("torch_sparse::spmm_max", &spmm_max);
@@ -7,10 +7,7 @@ from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor
import torch_scatter
from .utils import reductions, devices, grad_dtypes
reductions = ['sum', 'mean', 'min', 'max']
reductions = ['sum', 'mean']
@pytest.mark.parametrize('dtype,device,reduce',
...
import torch
reductions = ['sum', 'add', 'mean', 'min', 'max']
dtypes = [torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.float, torch.double]
...
@@ -46,4 +46,5 @@ from .diag import set_diag, remove_diag
from .add import add, add_, add_nnz, add_nnz_
from .mul import mul, mul_, mul_nnz, mul_nnz_
from .reduce import sum, mean, min, max
from .matmul import (spmm_sum, spmm_add, spmm_mean, spmm_min, spmm_max, spmm,
                     spspmm_sum, spspmm_add, spspmm, matmul)
import warnings
import os.path as osp
from typing import Optional, Union, Tuple

import torch
from torch_sparse.tensor import SparseTensor
@@ -29,8 +29,26 @@ except OSError:
        raise ImportError
        return mat
    def spmm_min_max_placeholder(rowptr: torch.Tensor, col: torch.Tensor,
                                 value: Optional[torch.Tensor],
                                 mat: torch.Tensor
                                 ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise ImportError
        return mat, mat

    def spspmm_sum_placeholder(
            rowptrA: torch.Tensor, colA: torch.Tensor,
            valueA: Optional[torch.Tensor], rowptrB: torch.Tensor,
            colB: torch.Tensor, valueB: Optional[torch.Tensor], K: int
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        raise ImportError
        return rowptrA, colA, valueA
    torch.ops.torch_sparse.spmm_sum = spmm_sum_placeholder
    torch.ops.torch_sparse.spmm_mean = spmm_mean_placeholder
    torch.ops.torch_sparse.spmm_min = spmm_min_max_placeholder
    torch.ops.torch_sparse.spmm_max = spmm_min_max_placeholder
    torch.ops.torch_sparse.spspmm_sum = spspmm_sum_placeholder
@torch.jit.script
@@ -80,6 +98,20 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
                                            colptr, csr2csc, other)
@torch.jit.script
def spmm_min(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    rowptr, col, value = src.csr()
    return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)


@torch.jit.script
def spmm_max(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    rowptr, col, value = src.csr()
    return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
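The scripted wrappers return both the reduced output and the argmin/argmax indices into the non-zero dimension. A hypothetical usage sketch; the SparseTensor constructor keywords follow the current torch_sparse API and are an assumption for this revision:

import torch
from torch_sparse import SparseTensor
from torch_sparse.matmul import spmm_min, spmm_max

# 2x3 sparse matrix with three non-zeros.
src = SparseTensor(row=torch.tensor([0, 0, 1]),
                   col=torch.tensor([0, 2, 1]),
                   value=torch.tensor([1.0, 2.0, 3.0]),
                   sparse_sizes=(2, 3))
other = torch.randn(3, 8)

out, argmax = spmm_max(src, other)  # out: [2, 8]
out, argmin = spmm_min(src, other)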
@torch.jit.script
def spmm(src: SparseTensor, other: torch.Tensor,
         reduce: str = "sum") -> torch.Tensor:
@@ -87,6 +119,37 @@ def spmm(src: SparseTensor, other: torch.Tensor,
        return spmm_sum(src, other)
    elif reduce == 'mean':
        return spmm_mean(src, other)
    elif reduce == 'min':
        return spmm_min(src, other)[0]
    elif reduce == 'max':
        return spmm_max(src, other)[0]
    else:
        raise ValueError
@torch.jit.script
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    rowptrA, colA, valueA = src.csr()
    rowptrB, colB, valueB = other.csr()
    M, K = src.sparse_size(0), other.sparse_size(1)
    rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
        rowptrA, colA, valueA, rowptrB, colB, valueB, K)
    return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
                        sparse_sizes=torch.Size([M, K]), is_sorted=True)


@torch.jit.script
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    return spspmm_sum(src, other)


@torch.jit.script
def spspmm(src: SparseTensor, other: SparseTensor,
           reduce: str = "sum") -> SparseTensor:
    if reduce == 'sum' or reduce == 'add':
        return spspmm_sum(src, other)
    elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
        raise NotImplementedError
    else:
        raise ValueError
@@ -95,172 +158,15 @@ def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
reduce: str = "sum"): reduce: str = "sum"):
if torch.is_tensor(other): if torch.is_tensor(other):
return spmm(src, other, reduce) return spmm(src, other, reduce)
elif isinstance(other, SparseTensor):
return spspmm(src, other, reduce)
else: else:
raise ValueError raise ValueError
SparseTensor.spmm = lambda self, other, reduce=None: spmm(self, other, reduce) SparseTensor.spmm = lambda self, other, reduce=None: spmm(self, other, reduce)
SparseTensor.spspmm = lambda self, other, reduce=None: spspmm(
self, other, reduce)
SparseTensor.matmul = lambda self, other, reduce=None: matmul( SparseTensor.matmul = lambda self, other, reduce=None: matmul(
self, other, reduce) self, other, reduce)
SparseTensor.__matmul__ = lambda self, other: matmul(self, other, 'sum') SparseTensor.__matmul__ = lambda self, other: matmul(self, other, 'sum')
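A hypothetical end-to-end example of the reduce-aware dispatch wired up above; again, the SparseTensor construction details are assumptions for this revision:

import torch
from torch_sparse import SparseTensor
from torch_sparse.matmul import matmul

src = SparseTensor(row=torch.tensor([0, 0, 1]),
                   col=torch.tensor([0, 2, 1]),
                   value=torch.tensor([1.0, 2.0, 3.0]),
                   sparse_sizes=(2, 3))
dense = torch.randn(3, 8)

out_sum = src @ dense                # __matmul__ defaults to reduce='sum'
out_max = matmul(src, dense, 'max')  # dense output; argmax indices are dropped by spmm()

other_sp = SparseTensor(row=torch.tensor([0, 2]),
                        col=torch.tensor([1, 0]),
                        value=torch.tensor([1.0, 1.0]),
                        sparse_sizes=(3, 2))
prod = matmul(src, other_sp, 'sum')  # sparse-sparse product via spspmm_sum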
# class SPMM(torch.autograd.Function):
# @staticmethod
# def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
# reduce):
# if mat.is_cuda:
# out, arg_out = torch.ops.torch_sparse_cuda.spmm(
# rowptr, col, value, mat, reduce)
# else:
# out, arg_out = torch.ops.torch_sparse_cpu.spmm(
# rowptr, col, value, mat, reduce)
# ctx.reduce = reduce
# ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
# csr2csc, arg_out)
# if reduce == 'min' or reduce == 'max':
# ctx.mark_non_differentiable(arg_out)
# return out, arg_out
# else:
# return out
# @staticmethod
# def backward(ctx, grad_out, *args):
# (row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
# arg_out) = ctx.saved_tensors
# invalid_arg_mask = arg_out_ind = None
# if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
# or ctx.needs_input_grad[4]):
# invalid_arg_mask = arg_out == col.size(0)
# arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)
# grad_value = None
# if ctx.needs_input_grad[3]:
# if ctx.reduce in ['sum', 'add', 'mean']:
# grad_value = ext(grad_out.is_cuda).spmm_val_bw(
# row, rowptr, col, mat, grad_out, ctx.reduce)
# elif ctx.reduce in ['min', 'max']:
# col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
# out = mat.gather(-2, col_tmp).mul_(grad_out)
# out.masked_fill_(invalid_arg_mask, 0)
# grad_value = scatter_add(out.flatten(), arg_out.flatten(),
# dim=0, dim_size=value.numel() + 1)
# grad_value = grad_value[:-1]
# grad_mat = None
# if ctx.needs_input_grad[4]:
# if ctx.reduce in ['sum', 'add']:
# value = value[csr2csc] if value is not None else value
# grad_mat, _ = ext(grad_out.is_cuda).spmm(
# colptr, row[csr2csc], value, grad_out, 'sum')
# elif ctx.reduce == 'mean':
# count = rowcount[row].to(mat.dtype).clamp_(min=1)
# value = count.pow_(-1) if value is None else value / count
# row = row[csr2csc]
# value = value[csr2csc] if value is not None else value
# grad_mat, _ = ext(grad_out.is_cuda).spmm(
# colptr, row, value, grad_out, 'sum')
# elif ctx.reduce in ['min', 'max']:
# if value is not None:
# value = value[arg_out_ind.flatten()].view_as(arg_out)
# value = value.mul_(grad_out)
# else:
# value = grad_out
# value.masked_fill_(invalid_arg_mask, 0)
# col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
# grad_mat = scatter_add(value, col_tmp, dim=-2,
# dim_size=mat.size(-2))
# return None, None, None, grad_value, grad_mat, None, None, None, None
# class SPSPMM(torch.autograd.Function):
# @staticmethod
# def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
# if rowptrA.is_cuda:
# rowptrC, colC, valueC = ext(True).spspmm(rowptrA, colA, valueA,
# rowptrB, colB, valueB, M,
# N, K)
# else:
# dtype = None
# if valueA is not None:
# dtype = valueA.dtype
# if valueB is not None:
# dtype = valueB.dtype
# if valueA is None:
# valueA = torch.ones(colA.numel(), dtype=dtype)
# A = scipy.sparse.csr_matrix((valueA, colA, rowptrA), (M, N))
# if valueB is None:
# valueB = torch.ones(colB.numel(), dtype=dtype)
# B = scipy.sparse.csr_matrix((valueB, colB, rowptrB), (N, K))
# C = A @ B
# rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
# colC = torch.from_numpy(C.indices).to(torch.int64)
# valueC = torch.from_numpy(C.data)
# valueC = valueC.to(dtype) if dtype is not None else None
# ctx.mark_non_differentiable(rowptrC, colC)
# # We cannot return `NoneType` in torch.autograd :(
# if valueC is None:
# return rowptrC, colC
# else:
# return rowptrC, colC, valueC
# @staticmethod
# def backward(ctx, grad_indexC, grad_rowptrC, *args):
# grad_valueA = None
# if ctx.needs_input_grad[2]:
# raise NotImplementedError
# grad_valueB = None
# if ctx.needs_input_grad[5]:
# raise NotImplementedError
# return (None, None, grad_valueA, None, None, grad_valueB, None, None,
# None)
# def matmul(src, other, reduce='sum'):
# assert src.dim() == 2 and src.size(-1) == other.size(-2)
# # Sparse-Dense Matrix Multiplication.
# if torch.is_tensor(other):
# assert reduce in ['sum', 'add', 'mean', 'min', 'max']
# rowptr, col, value = src.csr()
# row = None
# if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
# or other.requires_grad):
# row = src.storage.row
# rowcount = None
# if other.requires_grad and reduce in ['mean']:
# rowcount = src.storage.rowcount
# csr2csc = colptr = None
# if other.requires_grad and reduce in ['sum', 'add', 'mean']:
# csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
# return SPMM.apply(row, rowptr, col, value, other, rowcount, colptr,
# csr2csc, reduce)
# # Sparse-Sparse Matrix Multiplication.
# elif isinstance(other, src.__class__):
# assert reduce in ['sum', 'add']
# assert src.dim() == 2 and other.dim() == 2
# data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
# other.size(1))
# (rowptr, col), value = data[:2], data[2] if len(data) == 3 else None
# sparse_size = torch.Size([src.size(0), other.size(1)])
# return src.__class__(rowptr=rowptr, col=col, value=value,
# sparse_size=sparse_size, is_sorted=True)
# raise ValueError