scatter.cpp 8.36 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
#include <torch/script.h>

#include "cpu/scatter_cpu.h"
4
#include "utils.h"
rusty1s's avatar
rusty1s committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

#ifdef WITH_CUDA
#include "cuda/scatter_cuda.h"
#endif

// Reshapes `src` so that it can be expanded to the shape of `other`.
//
// A one-dimensional `src` first receives `dim` leading singleton axes.
// Trailing singleton axes are then appended until `src` has as many
// dimensions as `other`, and the result is expanded to `other.sizes()`.
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
  if (src.dim() == 1) {
    for (int64_t lead = 0; lead < dim; ++lead) {
      src = src.unsqueeze(0);
    }
  }

  // Pad on the right until the ranks agree.
  while (src.dim() < other.dim()) {
    src = src.unsqueeze(-1);
  }

  return src.expand(other.sizes().vec());
}

// Dispatches the scatter forward pass to the CPU or CUDA implementation,
// chosen by the device of `src`. All arguments are forwarded unchanged.
//
// Raises via AT_ERROR when `src` lives on a CUDA device but the extension
// was built without WITH_CUDA.
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
           torch::optional<torch::Tensor> optional_out,
           torch::optional<int64_t> dim_size, std::string reduce) {
  const bool on_cuda = src.device().is_cuda();
  if (!on_cuda) {
    return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
  }

#ifdef WITH_CUDA
  return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// Autograd function for the "sum" reduction of scatter_fw.
class ScatterSum : public torch::autograd::Function<ScatterSum> {
public:
  // Forward pass.
  //
  // Normalises a negative `dim` against `src.dim()`, broadcasts `index` to
  // the shape of `src` and returns the first result of
  // scatter_fw(..., "sum") as the single output.
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    dim = dim < 0 ? src.dim() + dim : dim;
    ctx->saved_data["dim"] = dim;

    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
    auto out = std::get<0>(result);

    // backward() only needs the broadcast index and `dim`.
    ctx->save_for_backward({index});
    // Presumably scatter_fw writes into a caller-supplied output in place,
    // hence it is reported to autograd as dirty -- TODO confirm.
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out};
  }

  // Backward pass.
  //
  // The gradient w.r.t. `src` is `grad_out` gathered along `dim` at the
  // saved broadcast index. The remaining four forward inputs receive an
  // undefined Variable.
  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto dim = ctx->saved_data["dim"].toInt();

    auto grad_in = torch::gather(grad_out, dim, index, false);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};

// Autograd function computing a scatter "mean": a "sum" scatter of `src`
// divided by the (clamped) number of contributions per output position.
class ScatterMean : public torch::autograd::Function<ScatterMean> {
public:
  // Forward pass.
  //
  // Normalises a negative `dim`, sums `src` via scatter_fw(..., "sum"),
  // then counts contributions by scattering a tensor of ones with the
  // original (un-broadcast) index and divides the sum by that count in
  // place.
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    dim = dim < 0 ? src.dim() + dim : dim;
    ctx->saved_data["dim"] = dim;

    // Keep the index as passed in: the count is computed on it before
    // broadcasting.
    auto old_index = index;

    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
    auto out = std::get<0>(result);

    auto ones = torch::ones(old_index.sizes(), src.options());
    // If the original index has too few dimensions to address `dim`, its
    // last dimension is used instead.
    result = scatter_fw(ones, old_index,
                        old_index.dim() <= dim ? old_index.dim() - 1 : dim,
                        torch::nullopt, out.size(dim), "sum");
    auto count = std::get<0>(result);
    // A lower bound of 1 avoids dividing by zero for positions without any
    // contribution.
    count.clamp_(1);
    count = broadcast(count, out, dim);
    out.div_(count);

    ctx->save_for_backward({index, count});
    // Presumably scatter_fw writes into a caller-supplied output in place,
    // hence it is reported to autograd as dirty -- TODO confirm.
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out};
  }

  // Backward pass.
  //
  // Gathers both `grad_out` and the saved count along `dim` at the saved
  // broadcast index and divides the gathered gradient by the gathered
  // count. The remaining four forward inputs receive an undefined Variable.
  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto count = saved[1];
    auto dim = ctx->saved_data["dim"].toInt();

    count = torch::gather(count, dim, index, false);
    auto grad_in = torch::gather(grad_out, dim, index, false);
    grad_in.div_(count);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};

// Autograd function for the "min" reduction of scatter_fw. Produces two
// outputs: the reduced tensor and the second result of scatter_fw
// (`arg_out`), which is marked non-differentiable.
class ScatterMin : public torch::autograd::Function<ScatterMin> {
public:
  // Forward pass.
  //
  // Normalises a negative `dim`, broadcasts `index` to the shape of `src`
  // and returns {out, arg_out} from scatter_fw(..., "min"). The optional
  // second result is unwrapped with .value(), so it must be present.
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    dim = dim < 0 ? src.dim() + dim : dim;
    ctx->saved_data["dim"] = dim;
    ctx->saved_data["src_shape"] = src.sizes();

    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "min");
    auto out = std::get<0>(result);
    auto arg_out = std::get<1>(result).value();

    // backward() routes gradients purely through arg_out; the index itself
    // is not needed there and is therefore not saved.
    ctx->save_for_backward({arg_out});
    ctx->mark_non_differentiable({arg_out});
    // Presumably scatter_fw writes into a caller-supplied output in place,
    // hence it is reported to autograd as dirty -- TODO confirm.
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out, arg_out};
  }

  // Backward pass.
  //
  // Scatters `grad_out` into a zero tensor along `dim` at the positions in
  // arg_out. The remaining four forward inputs receive an undefined
  // Variable.
  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto arg_out = ctx->get_saved_variables()[0];
    auto dim = ctx->saved_data["dim"].toInt();
    auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());

    // One extra slot along `dim`, dropped again by narrow() below.
    // Presumably arg_out uses src.size(dim) as a sentinel for output
    // positions without any contribution -- TODO confirm against the
    // scatter_cpu / scatter_cuda kernels.
    src_shape[dim] += 1;
    auto grad_in = torch::zeros(src_shape, grad_out.options());
    grad_in.scatter_(dim, arg_out, grad_out);
    grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};

// Autograd function for the "max" reduction of scatter_fw. Produces two
// outputs: the reduced tensor and the second result of scatter_fw
// (`arg_out`), which is marked non-differentiable.
class ScatterMax : public torch::autograd::Function<ScatterMax> {
public:
  // Forward pass.
  //
  // Normalises a negative `dim`, broadcasts `index` to the shape of `src`
  // and returns {out, arg_out} from scatter_fw(..., "max"). The optional
  // second result is unwrapped with .value(), so it must be present.
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    dim = dim < 0 ? src.dim() + dim : dim;
    ctx->saved_data["dim"] = dim;
    ctx->saved_data["src_shape"] = src.sizes();

    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "max");
    auto out = std::get<0>(result);
    auto arg_out = std::get<1>(result).value();

    // backward() routes gradients purely through arg_out; the index itself
    // is not needed there and is therefore not saved.
    ctx->save_for_backward({arg_out});
    ctx->mark_non_differentiable({arg_out});
    // Presumably scatter_fw writes into a caller-supplied output in place,
    // hence it is reported to autograd as dirty -- TODO confirm.
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out, arg_out};
  }

  // Backward pass.
  //
  // Scatters `grad_out` into a zero tensor along `dim` at the positions in
  // arg_out. The remaining four forward inputs receive an undefined
  // Variable.
  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto arg_out = ctx->get_saved_variables()[0];
    auto dim = ctx->saved_data["dim"].toInt();
    auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());

    // One extra slot along `dim`, dropped again by narrow() below.
    // Presumably arg_out uses src.size(dim) as a sentinel for output
    // positions without any contribution -- TODO confirm against the
    // scatter_cpu / scatter_cuda kernels.
    src_shape[dim] += 1;
    auto grad_in = torch::zeros(src_shape, grad_out.options());
    grad_in.scatter_(dim, arg_out, grad_out);
    grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};

// Operator entry point: runs ScatterSum through autograd and returns its
// single output tensor.
torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
                          torch::optional<torch::Tensor> optional_out,
                          torch::optional<int64_t> dim_size) {
  auto outputs = ScatterSum::apply(src, index, dim, optional_out, dim_size);
  return outputs[0];
}

// Operator entry point: runs ScatterMean through autograd and returns its
// single output tensor.
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
                           torch::optional<torch::Tensor> optional_out,
                           torch::optional<int64_t> dim_size) {
  auto outputs = ScatterMean::apply(src, index, dim, optional_out, dim_size);
  return outputs[0];
}

// Operator entry point: runs ScatterMin through autograd and returns its
// two outputs (reduced tensor, arg tensor) as a tuple.
std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size) {
  auto outputs = ScatterMin::apply(src, index, dim, optional_out, dim_size);
  torch::Tensor reduced = outputs[0];
  torch::Tensor arg = outputs[1];
  return std::tuple<torch::Tensor, torch::Tensor>(reduced, arg);
}

// Operator entry point: runs ScatterMax through autograd and returns its
// two outputs (reduced tensor, arg tensor) as a tuple.
std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size) {
  auto outputs = ScatterMax::apply(src, index, dim, optional_out, dim_size);
  torch::Tensor reduced = outputs[0];
  torch::Tensor arg = outputs[1];
  return std::tuple<torch::Tensor, torch::Tensor>(reduced, arg);
}

// Registers the four entry points above as operators under the
// "torch_scatter::" prefix. The file-scope static keeps the
// torch::RegisterOperators object returned by the builder chain.
static auto registry = torch::RegisterOperators()
                           .op("torch_scatter::scatter_sum", &scatter_sum)
                           .op("torch_scatter::scatter_mean", &scatter_mean)
                           .op("torch_scatter::scatter_min", &scatter_min)
                           .op("torch_scatter::scatter_max", &scatter_max);