Commit 68f4609c authored by rusty1s's avatar rusty1s
Browse files

added scatter_mul back in

parent 12722728
...@@ -70,6 +70,37 @@ public: ...@@ -70,6 +70,37 @@ public:
} }
}; };
class ScatterMul : public torch::autograd::Function<ScatterMul> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "mul");
auto out = std::get<0>(result);
ctx->save_for_backward({src, index, out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto src = saved[0];
auto index = saved[1];
auto out = saved[2];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out * out, dim, index, false).div_(src);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMean : public torch::autograd::Function<ScatterMean> { class ScatterMean : public torch::autograd::Function<ScatterMean> {
public: public:
static variable_list forward(AutogradContext *ctx, Variable src, static variable_list forward(AutogradContext *ctx, Variable src,
...@@ -197,6 +228,12 @@ torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -197,6 +228,12 @@ torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0]; return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
} }
torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
}
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out, torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) { torch::optional<int64_t> dim_size) {
...@@ -221,6 +258,7 @@ scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -221,6 +258,7 @@ scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
static auto registry = torch::RegisterOperators() static auto registry = torch::RegisterOperators()
.op("torch_scatter::scatter_sum", &scatter_sum) .op("torch_scatter::scatter_sum", &scatter_sum)
.op("torch_scatter::scatter_mul", &scatter_mul)
.op("torch_scatter::scatter_mean", &scatter_mean) .op("torch_scatter::scatter_mean", &scatter_mean)
.op("torch_scatter::scatter_min", &scatter_min) .op("torch_scatter::scatter_min", &scatter_min)
.op("torch_scatter::scatter_max", &scatter_max); .op("torch_scatter::scatter_max", &scatter_max);
...@@ -7,6 +7,8 @@ import torch_scatter ...@@ -7,6 +7,8 @@ import torch_scatter
from .utils import reductions, tensor, dtypes, devices from .utils import reductions, tensor, dtypes, devices
reductions = reductions + ['mul']
tests = [ tests = [
{ {
'src': [1, 3, 2, 4, 5, 6], 'src': [1, 3, 2, 4, 5, 6],
...@@ -14,6 +16,7 @@ tests = [ ...@@ -14,6 +16,7 @@ tests = [
'dim': 0, 'dim': 0,
'sum': [3, 12, 0, 6], 'sum': [3, 12, 0, 6],
'add': [3, 12, 0, 6], 'add': [3, 12, 0, 6],
'mul': [2, 60, 1, 6],
'mean': [1.5, 4, 0, 6], 'mean': [1.5, 4, 0, 6],
'min': [1, 3, 0, 6], 'min': [1, 3, 0, 6],
'arg_min': [0, 1, 6, 5], 'arg_min': [0, 1, 6, 5],
...@@ -26,6 +29,7 @@ tests = [ ...@@ -26,6 +29,7 @@ tests = [
'dim': 0, 'dim': 0,
'sum': [[4, 6], [21, 24], [0, 0], [11, 12]], 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
'add': [[4, 6], [21, 24], [0, 0], [11, 12]], 'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
'mul': [[1 * 3, 2 * 4], [5 * 7 * 9, 6 * 8 * 10], [1, 1], [11, 12]],
'mean': [[2, 3], [7, 8], [0, 0], [11, 12]], 'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
'min': [[1, 2], [5, 6], [0, 0], [11, 12]], 'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]], 'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]],
...@@ -38,6 +42,7 @@ tests = [ ...@@ -38,6 +42,7 @@ tests = [
'dim': 1, 'dim': 1,
'sum': [[4, 21, 0, 11], [12, 18, 12, 0]], 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
'add': [[4, 21, 0, 11], [12, 18, 12, 0]], 'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
'mul': [[1 * 3, 5 * 7 * 9, 1, 11], [2 * 4 * 6, 8 * 10, 12, 1]],
'mean': [[2, 7, 0, 11], [4, 9, 12, 0]], 'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
'min': [[1, 5, 0, 11], [2, 8, 12, 0]], 'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]], 'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]],
...@@ -50,6 +55,7 @@ tests = [ ...@@ -50,6 +55,7 @@ tests = [
'dim': 1, 'dim': 1,
'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'mul': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 11 * 13]]],
'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], 'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], 'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]], 'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]],
...@@ -62,6 +68,7 @@ tests = [ ...@@ -62,6 +68,7 @@ tests = [
'dim': 1, 'dim': 1,
'sum': [[4], [6]], 'sum': [[4], [6]],
'add': [[4], [6]], 'add': [[4], [6]],
'mul': [[3], [8]],
'mean': [[2], [3]], 'mean': [[2], [3]],
'min': [[1], [2]], 'min': [[1], [2]],
'arg_min': [[0], [0]], 'arg_min': [[0], [0]],
...@@ -74,6 +81,7 @@ tests = [ ...@@ -74,6 +81,7 @@ tests = [
'dim': 1, 'dim': 1,
'sum': [[[4, 4]], [[6, 6]]], 'sum': [[[4, 4]], [[6, 6]]],
'add': [[[4, 4]], [[6, 6]]], 'add': [[[4, 4]], [[6, 6]]],
'mul': [[[3, 3]], [[8, 8]]],
'mean': [[[2, 2]], [[3, 3]]], 'mean': [[[2, 2]], [[3, 3]]],
'min': [[[1, 1]], [[2, 2]]], 'min': [[[1, 1]], [[2, 2]]],
'arg_min': [[[0, 0]], [[0, 0]]], 'arg_min': [[[0, 0]], [[0, 0]]],
...@@ -125,6 +133,8 @@ def test_out(test, reduce, dtype, device): ...@@ -125,6 +133,8 @@ def test_out(test, reduce, dtype, device):
if reduce == 'sum' or reduce == 'add': if reduce == 'sum' or reduce == 'add':
expected = expected - 2 expected = expected - 2
elif reduce == 'mul':
expected = out # We can not really test this here.
elif reduce == 'mean': elif reduce == 'mean':
expected = out # We can not really test this here. expected = out # We can not really test this here.
elif reduce == 'min': elif reduce == 'min':
......
...@@ -58,8 +58,8 @@ if torch.cuda.is_available() and torch.version.cuda: # pragma: no cover ...@@ -58,8 +58,8 @@ if torch.cuda.is_available() and torch.version.cuda: # pragma: no cover
f'{major}.{minor}. Please reinstall the torch_scatter that ' f'{major}.{minor}. Please reinstall the torch_scatter that '
f'matches your PyTorch install.') f'matches your PyTorch install.')
from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min, from .scatter import (scatter_sum, scatter_add, scatter_mul, scatter_mean,
scatter_max, scatter) # noqa scatter_min, scatter_max, scatter) # noqa
from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr, from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
segment_min_csr, segment_max_csr, segment_csr, segment_min_csr, segment_max_csr, segment_csr,
gather_csr) # noqa gather_csr) # noqa
...@@ -72,6 +72,7 @@ from .composite import (scatter_std, scatter_logsumexp, scatter_softmax, ...@@ -72,6 +72,7 @@ from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
__all__ = [ __all__ = [
'scatter_sum', 'scatter_sum',
'scatter_add', 'scatter_add',
'scatter_mul',
'scatter_mean', 'scatter_mean',
'scatter_min', 'scatter_min',
'scatter_max', 'scatter_max',
......
...@@ -31,6 +31,13 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -31,6 +31,13 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
return scatter_sum(src, index, dim, out, dim_size) return scatter_sum(src, index, dim, out, dim_size)
@torch.jit.script
def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)
@torch.jit.script @torch.jit.script
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
...@@ -127,8 +134,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -127,8 +134,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
with size :attr:`dim_size` at dimension :attr:`dim`. with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor If :attr:`dim_size` is not given, a minimal sized output tensor
according to :obj:`index.max() + 1` is returned. according to :obj:`index.max() + 1` is returned.
:param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`, :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`,
:obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`) :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
:rtype: :class:`Tensor` :rtype: :class:`Tensor`
...@@ -150,6 +157,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -150,6 +157,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
""" """
if reduce == 'sum' or reduce == 'add': if reduce == 'sum' or reduce == 'add':
return scatter_sum(src, index, dim, out, dim_size) return scatter_sum(src, index, dim, out, dim_size)
if reduce == 'mul':
return scatter_mul(src, index, dim, out, dim_size)
elif reduce == 'mean': elif reduce == 'mean':
return scatter_mean(src, index, dim, out, dim_size) return scatter_mean(src, index, dim, out, dim_size)
elif reduce == 'min': elif reduce == 'min':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment