import torch

from torch_scatter.helpers import min_value, max_value

if torch.cuda.is_available():
    from torch_scatter import segment_cuda, gather_cuda


class SegmentCOO(torch.autograd.Function):
    @staticmethod
    def forward(ctx, src, index, out, dim_size, reduce):
        assert reduce in ['add', 'mean', 'min', 'max']

        if out is not None:
            ctx.mark_dirty(out)
        ctx.reduce = reduce
        ctx.src_size = list(src.size())

        fill_value = 0
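        # If no output tensor is given, 'min'/'max' reductions start from the
        # dtype's extreme value so the kernel can reduce in place; segments
        # left untouched are reset to zero after the kernel call below.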
        if out is None:
            dim_size = index.max().item() + 1 if dim_size is None else dim_size
            size = list(src.size())
            size[index.dim() - 1] = dim_size

            if reduce == 'min':
                fill_value = max_value(src.dtype)
            elif reduce == 'max':
                fill_value = min_value(src.dtype)

            out = src.new_full(size, fill_value)

        out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)

        if fill_value != 0:
            out.masked_fill_(out == fill_value, 0)

        ctx.save_for_backward(index, arg_out)

        if reduce == 'min' or reduce == 'max':
            return out, arg_out
        else:
            return out

    @staticmethod
    def backward(ctx, grad_out, *args):
        (index, arg_out), src_size = ctx.saved_tensors, ctx.src_size

        grad_src = None
        if ctx.needs_input_grad[0]:
            if ctx.reduce == 'add':
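                # For 'add', each input position simply receives the upstream
                # gradient of its segment, gathered back via the same index.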
                grad_src = gather_cuda.gather_coo(grad_out, index,
                                                  grad_out.new_empty(src_size))
            elif ctx.reduce == 'mean':
                grad_src = gather_cuda.gather_coo(grad_out, index,
                                                  grad_out.new_empty(src_size))
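                # For 'mean', arg_out stores the number of inputs per
                # segment; gather the counts to the source layout and divide.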
                count = arg_out
                count = gather_cuda.gather_coo(
                    count, index, count.new_empty(src_size[:index.dim()]))
                for _ in range(grad_out.dim() - index.dim()):
                    count = count.unsqueeze(-1)
                grad_src.div_(count)
            elif ctx.reduce == 'min' or ctx.reduce == 'max':
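                # Scatter the gradient to the positions recorded in arg_out.
                # The reduction dimension is grown by one slot so that arg
                # indices pointing past the valid range (e.g. for empty
                # segments) land in a throwaway slice, which the narrow
                # below drops again.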
                src_size[index.dim() - 1] += 1
                grad_src = grad_out.new_zeros(src_size).scatter_(
                    index.dim() - 1, arg_out, grad_out)
                grad_src = grad_src.narrow(index.dim() - 1, 0,
                                           src_size[index.dim() - 1] - 1)

        return grad_src, None, None, None, None


class SegmentCSR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, src, indptr, out, reduce):
        assert reduce in ['add', 'mean', 'min', 'max']

        if out is not None:
            ctx.mark_dirty(out)
        ctx.reduce = reduce
        ctx.src_size = list(src.size())

        out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
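        # arg_out is None for 'add'/'mean' and holds the arg indices for
        # 'min'/'max' (see the return statement below).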
        ctx.save_for_backward(indptr, arg_out)

        return out if arg_out is None else (out, arg_out)

    @staticmethod
    def backward(ctx, grad_out, *args):
        (indptr, arg_out), src_size = ctx.saved_tensors, ctx.src_size

        grad_src = None
        if ctx.needs_input_grad[0]:
            if ctx.reduce == 'add':
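                # For 'add', the upstream gradient is gathered back to every
                # input position of the corresponding segment.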
                grad_src = gather_cuda.gather_csr(grad_out, indptr,
                                                  grad_out.new_empty(src_size))
            elif ctx.reduce == 'mean':
                grad_src = gather_cuda.gather_csr(grad_out, indptr,
                                                  grad_out.new_empty(src_size))
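                # The per-segment counts are the differences of consecutive
                # index pointers; gather them to the source layout and
                # normalize the gradient.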
                indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
                indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
                count = (indptr2 - indptr1).to(grad_src.dtype)
                count = gather_cuda.gather_csr(
                    count, indptr, count.new_empty(src_size[:indptr.dim()]))
                for _ in range(grad_out.dim() - indptr.dim()):
                    count = count.unsqueeze(-1)
                grad_src.div_(count)
            elif ctx.reduce == 'min' or ctx.reduce == 'max':
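                # Same trick as in SegmentCOO: scatter the gradient to the
                # arg positions, using one extra slot along the reduction
                # dimension that the narrow below drops again.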
                src_size[indptr.dim() - 1] += 1
                grad_src = grad_out.new_zeros(src_size).scatter_(
                    indptr.dim() - 1, arg_out, grad_out)
                grad_src = grad_src.narrow(indptr.dim() - 1, 0,
                                           src_size[indptr.dim() - 1] - 1)

        return grad_src, None, None, None


def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/segment_coo.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along the last dimension of
    :attr:`index`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the
    corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and
    :math:`m`-dimensional tensors with
    size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
    :math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an
    :math:`n`-dimensional tensor with size
    :math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1` in ascending order.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.
    For one-dimensional tensors with :obj:`reduce="add"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    In contrast to :meth:`scatter`, this method expects values in :attr:`index`
    **to be sorted** along dimension :obj:`index.dim() - 1`.
    Due to the use of sorted indices, :meth:`segment_coo` is usually faster
    than the more general :meth:`scatter` operation.

    For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
    second tensor representing the :obj:`argmin` and :obj:`argmax`,
    respectively.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point tensors, this can lead to slightly different
        results across invocations.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The sorted indices of elements to segment.
            The number of dimensions of :attr:`index` needs to be less than
            or equal to that of :attr:`src`.
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension
            :obj:`index.dim() - 1`.
            If :attr:`dim_size` is not given, a minimal sized output tensor
            according to :obj:`index.max() + 1` is returned.
            (default: :obj:`None`)
        reduce (string, optional): The reduce operation (:obj:`"add"`,
            :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
            (default: :obj:`"add"`)

    :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*

    .. code-block:: python

        from torch_scatter import segment_coo

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 0, 1, 1, 1, 2])
        index = index.view(1, -1)  # Broadcasting in the first and last dim.

        out = segment_coo(src, index, reduce="add")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
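
    For :obj:`reduce="min"` and :obj:`reduce="max"`, the call additionally
    returns the arg indices (reusing :obj:`src` and :obj:`index` from above):

    .. code-block:: python

        out, argmax = segment_coo(src, index, reduce="max")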
    """
    return SegmentCOO.apply(src, index, out, dim_size, reduce)


def segment_csr(src, indptr, out=None, reduce="add"):
    r"""
    Reduces all values from the :attr:`src` tensor into :attr:`out` within the
    ranges specified in the :attr:`indptr` tensor along the last dimension of
    :attr:`indptr`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the
    corresponding range index in :attr:`indptr` for dimension
    :obj:`indptr.dim() - 1`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
    :math:`m`-dimensional tensors with
    size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
    :math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an
    :math:`n`-dimensional tensor with size
    :math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`indptr` must be between :math:`0` and
    :math:`x_m` in ascending order.
    The :attr:`indptr` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.
    For one-dimensional tensors with :obj:`reduce="add"`, the operation
    computes

    .. math::
        \mathrm{out}_i =
        \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.

    Due to the use of index pointers, :meth:`segment_csr` is the fastest
    method to apply for grouped reductions.

    For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
    second tensor representing the :obj:`argmin` and :obj:`argmax`,
    respectively.

    .. note::

        In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
        operation is **fully-deterministic**.

    Args:
        src (Tensor): The source tensor.
        indptr (LongTensor): The index pointers between elements to segment.
            The number of dimensions of :attr:`indptr` needs to be less than
            or equal to that of :attr:`src`.
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        reduce (string, optional): The reduce operation (:obj:`"add"`,
            :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
            (default: :obj:`"add"`)

    :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*

    .. code-block:: python

        from torch_scatter import segment_csr

        src = torch.randn(10, 6, 64)
        indptr = torch.tensor([0, 2, 5, 6])
        indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

        out = segment_csr(src, indptr, reduce="add")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
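
    For :obj:`reduce="min"` and :obj:`reduce="max"`, the call additionally
    returns the arg indices (reusing :obj:`src` and :obj:`indptr` from above):

    .. code-block:: python

        out, argmin = segment_csr(src, indptr, reduce="min")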
    """
    return SegmentCSR.apply(src, indptr, out, reduce)