Commit 28cb8de4 authored by rusty1s's avatar rusty1s
Browse files

Fix rowcount/colcount computation and add spmm mean

parent d613c5c0
......@@ -24,7 +24,7 @@ spmm_fw(torch::Tensor rowptr, torch::Tensor col,
// Gradient of the sparse values of an spmm call. Dispatches to the CUDA
// implementation based on `row`'s device (the commit's fix: the old code
// inspected `rowptr` instead, and the diff left both `if` lines in place,
// unbalancing the braces).
torch::Tensor spmm_value_bw(torch::Tensor row, torch::Tensor rowptr,
                            torch::Tensor col, torch::Tensor mat,
                            torch::Tensor grad, std::string reduce) {
  if (row.device().is_cuda()) {
#ifdef WITH_CUDA
    return spmm_value_bw_cuda(row, rowptr, col, mat, grad, reduce);
#else
......@@ -42,25 +42,21 @@ using torch::autograd::variable_list;
// Autograd node for sparse-dense matrix multiplication with "sum" reduction.
// NOTE: the diff residue contained both the pre- and post-commit revision of
// this class (duplicate `row`/`colptr`/`csr2csc`/`opt_value` declarations,
// old and new parameter names); this is the coherent final revision.
class SPMMSum : public torch::autograd::Function<SPMMSum> {
public:
  static variable_list forward(AutogradContext *ctx,
                               torch::optional<Variable> opt_row,
                               Variable rowptr, Variable col, Variable value,
                               torch::optional<Variable> opt_colptr,
                               torch::optional<Variable> opt_csr2csc,
                               Variable mat) {
    // Saved variables cannot be optional, so absent tensors are stored as
    // empty (undefined) tensors.
    auto row = opt_row.has_value() ? opt_row.value() : torch::Tensor();
    auto colptr = opt_colptr.has_value() ? opt_colptr.value() : torch::Tensor();
    auto csr2csc =
        opt_csr2csc.has_value() ? opt_csr2csc.value() : torch::Tensor();

    // An empty `value` tensor encodes "no explicit values" (all-ones entries).
    torch::optional<torch::Tensor> opt_value = torch::nullopt;
    if (value.numel() > 0)
      opt_value = value;

    auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "sum"));
    ctx->save_for_backward({row, rowptr, col, value, colptr, csr2csc, mat});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
         colptr = saved[4], csr2csc = saved[5], mat = saved[6];

    // Gradient w.r.t. the sparse values (only if explicit values exist).
    auto grad_value = Variable();
    if (value.numel() > 0 &&
        torch::autograd::any_variable_requires_grad({value})) {
      grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "sum");
    }

    // Gradient w.r.t. the dense matrix: spmm with the transposed (CSC)
    // layout, obtained via the `csr2csc` permutation.
    auto grad_mat = Variable();
    if (torch::autograd::any_variable_requires_grad({mat})) {
      torch::optional<torch::Tensor> opt_value = torch::nullopt;
      if (value.numel() > 0)
        opt_value = value.index_select(0, csr2csc);
      grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
                                     opt_value, grad_out, "sum"));
    }

    // One (possibly undefined) gradient slot per forward argument.
    return {Variable(), Variable(), Variable(), grad_value,
            Variable(), Variable(), grad_mat};
  }
};
torch::Tensor spmm_sum(torch::optional<torch::Tensor> optional_row,
class SPMMMean : public torch::autograd::Function<SPMMMean> {
public:
static variable_list forward(AutogradContext *ctx,
torch::optional<Variable> opt_row,
Variable rowptr, Variable col, Variable value,
torch::optional<Variable> opt_rowcount,
torch::optional<Variable> opt_colptr,
torch::optional<Variable> opt_csr2csc,
Variable mat) {
auto row = opt_row.has_value() ? opt_row.value() : torch::Tensor();
auto rowcount =
opt_rowcount.has_value() ? opt_rowcount.value() : torch::Tensor();
auto colptr = opt_colptr.has_value() ? opt_colptr.value() : torch::Tensor();
auto csr2csc =
opt_csr2csc.has_value() ? opt_csr2csc.value() : torch::Tensor();
torch::optional<torch::Tensor> opt_value = torch::nullopt;
if (value.numel() > 0)
opt_value = value;
auto out = std::get<0>(spmm_fw(rowptr, col, opt_value, mat, "mean"));
ctx->save_for_backward(
{row, rowptr, col, value, rowcount, colptr, csr2csc, mat});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto row = saved[0], rowptr = saved[1], col = saved[2], value = saved[3],
rowcount = saved[4], colptr = saved[5], csr2csc = saved[6],
mat = saved[7];
auto grad_value = Variable();
if (value.numel() > 0 &&
torch::autograd::any_variable_requires_grad({value})) {
grad_value = spmm_value_bw(row, rowptr, col, mat, grad_out, "mean");
}
auto grad_mat = Variable();
if (torch::autograd::any_variable_requires_grad({mat})) {
row = row.index_select(0, csr2csc);
rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
rowcount.clamp_(1);
if (value.numel() > 0)
rowcount = value.index_select(0, csr2csc).div(rowcount);
else
rowcount.pow_(-1);
grad_mat = std::get<0>(spmm_fw(colptr, row, rowcount, grad_out, "sum"));
}
return {Variable(), Variable(), Variable(), grad_value,
Variable(), Variable(), Variable(), grad_mat};
}
};
// Public entry point for spmm with "sum" reduction. (The diff residue
// contained both the old and new signature/body; this is the deduplicated
// final revision.)
torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
                       torch::Tensor rowptr, torch::Tensor col,
                       torch::optional<torch::Tensor> opt_value,
                       torch::optional<torch::Tensor> opt_colptr,
                       torch::optional<torch::Tensor> opt_csr2csc,
                       torch::Tensor mat) {
  // Since we cannot return an *optional* gradient, we need to convert
  // `opt_value` to an empty sized tensor first :(
  auto value = opt_value.has_value() ? opt_value.value() : torch::Tensor();
  return SPMMSum::apply(opt_row, rowptr, col, value, opt_colptr, opt_csr2csc,
                        mat)[0];
}
// Public entry point for spmm with "mean" reduction.
torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
                        torch::Tensor rowptr, torch::Tensor col,
                        torch::optional<torch::Tensor> opt_value,
                        torch::optional<torch::Tensor> opt_rowcount,
                        torch::optional<torch::Tensor> opt_colptr,
                        torch::optional<torch::Tensor> opt_csr2csc,
                        torch::Tensor mat) {
  // Autograd nodes cannot return optional gradients, so an absent `value`
  // is materialized as an empty (undefined) tensor before entering apply().
  torch::Tensor value;
  if (opt_value.has_value())
    value = opt_value.value();
  return SPMMMean::apply(opt_row, rowptr, col, value, opt_rowcount, opt_colptr,
                         opt_csr2csc, mat)[0];
}
// Register both custom ops under the `torch_sparse` namespace. The diff
// residue defined `registry` twice (old and new revision), which is a
// duplicate-symbol error; only the final registration is kept.
static auto registry = torch::RegisterOperators()
                           .op("torch_sparse::spmm_sum", &spmm_sum)
                           .op("torch_sparse::spmm_mean", &spmm_mean);
......@@ -10,7 +10,7 @@ import torch_scatter
from .utils import devices, grad_dtypes
# Reductions currently exercised by the test suite ('min'/'max' are not yet
# wired up for the new kernels). The diff residue carried three successive
# assignments; only the final one is meaningful.
reductions = ['sum', 'mean']
@pytest.mark.parametrize('dtype,device,reduce',
......
......@@ -19,7 +19,18 @@ except OSError:
raise ImportError
return mat
def spmm_mean_placeholder(row: Optional[torch.Tensor],
                          rowptr: torch.Tensor, col: torch.Tensor,
                          value: Optional[torch.Tensor],
                          rowcount: Optional[torch.Tensor],
                          colptr: Optional[torch.Tensor],
                          csr2csc: Optional[torch.Tensor],
                          mat: torch.Tensor) -> torch.Tensor:
    # Stub with the same signature as the compiled `spmm_mean` op; calling it
    # always raises ImportError (used when the extension failed to load).
    raise ImportError
    # Unreachable, but presumably kept so the annotated return type can be
    # satisfied for scripting/type-checking — TODO confirm.
    return mat
# Route the custom ops to the Python placeholder stubs so that calling them
# surfaces the import failure instead of a missing-op error.
torch.ops.torch_sparse.spmm_sum = spmm_sum_placeholder
torch.ops.torch_sparse.spmm_mean = spmm_mean_placeholder
@torch.jit.script
......@@ -47,11 +58,35 @@ def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
return spmm_sum(src, other)
@torch.jit.script
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Sparse-dense matrix product with 'mean' reduction, autograd-aware."""
    rowptr, col, value = src.csr()

    # Read the cached (possibly None) fields directly so nothing extra is
    # built when no gradients are required.
    row = src.storage._row
    rowcount = src.storage._rowcount
    csr2csc = src.storage._csr2csc
    colptr = src.storage._colptr

    if value is not None and value.requires_grad:
        # The value-gradient path needs the COO row vector.
        row = src.storage.row()

    if other.requires_grad:
        # The mat-gradient path additionally needs row counts and the CSC
        # layout (colptr + csr2csc permutation); the public accessors
        # presumably compute and cache them on demand — confirm in storage.
        row = src.storage.row()
        rowcount = src.storage.rowcount()
        csr2csc = src.storage.csr2csc()
        colptr = src.storage.colptr()

    return torch.ops.torch_sparse.spmm_mean(row, rowptr, col, value, rowcount,
                                            colptr, csr2csc, other)
@torch.jit.script
def spmm(src: SparseTensor, other: torch.Tensor,
         reduce: str = "sum") -> torch.Tensor:
    """Sparse-dense matrix product dispatching on the requested reduction.

    :param src: sparse left operand
    :param other: dense right operand
    :param reduce: 'sum'/'add' or 'mean'
    :raises ValueError: on any other `reduce` string
    """
    if reduce == 'sum' or reduce == 'add':
        return spmm_sum(src, other)
    elif reduce == 'mean':
        return spmm_mean(src, other)
    else:
        # Bare `raise ValueError` gave no diagnostic; name the valid options.
        raise ValueError("`reduce` must be one of 'sum', 'add' or 'mean'")
......
......@@ -274,7 +274,7 @@ class SparseStorage(object):
return rowcount
rowptr = self.rowptr()
rowcount = rowptr[1:] - rowptr[1:]
rowcount = rowptr[1:] - rowptr[:-1]
self._rowcount = rowcount
return rowcount
......@@ -306,7 +306,7 @@ class SparseStorage(object):
colptr = self._colptr
if colptr is not None:
colcount = colptr[1:] - colptr[1:]
colcount = colptr[1:] - colptr[:-1]
else:
colcount = scatter_add(torch.ones_like(self._col), self._col,
dim_size=self._sparse_sizes[1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment