functional.py 21.7 KB
Newer Older
yan.yan's avatar
yan.yan committed
1
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import sys
yan.yan's avatar
add act  
yan.yan committed
16
import pickle
17

traveller59's avatar
traveller59 committed
18
19
20
import torch
from torch import nn
from torch.autograd import Function
yan.yan's avatar
yan.yan committed
21
from typing import Optional, TypeVar
yan.yan's avatar
yan.yan committed
22
from spconv.pytorch.core import SparseConvTensor
23
from spconv.tools import CUDAKernelTimer
24
from spconv.pytorch import ops, SparseConvTensor
yan.yan's avatar
yan.yan committed
25
from spconv.pytorch.constants import PYTORCH_VERSION
26
from spconv.debug_utils import spconv_save_debug_data
yan.yan's avatar
v2.1  
yan.yan committed
27
28
from torch.autograd.function import once_differentiable
import numpy as np
29
from pathlib import Path
30
31
from spconv.pytorch.hash import HashTable
from cumm.gemm.layout import to_stride
yan.yan's avatar
v2.1  
yan.yan committed
32
from typing import List
yan.yan's avatar
add act  
yan.yan committed
33
34
from functools import reduce
from cumm import tensorview as tv
35

36
_MAX_INT32 = 2147483647
37

yan.yan's avatar
yan.yan committed
38
39
_T = TypeVar("_T")

yan.yan's avatar
add act  
yan.yan committed
40

yan.yan's avatar
yan.yan committed
41
42
43
def identity_decorator(func: _T) -> _T:
    """No-op decorator: hand back ``func`` untouched.

    Used below in place of the AMP ``custom_fwd``/``custom_bwd`` decorators
    when the installed PyTorch is too old to provide them.
    """
    return func


# Decorators applied to forward/backward of the custom autograd Functions in
# this module. Per the PyTorch AMP docs, ``custom_fwd(cast_inputs=...)`` casts
# incoming floating-point CUDA tensors to the given dtype (here float16) when
# forward is called inside an autocast-enabled region. The matching
# ``custom_bwd`` runs backward with the same autocast state as forward.
if PYTORCH_VERSION < [1, 6, 0]:
    # No AMP decorators exist before 1.6: fall back to no-ops.
    _TORCH_CUSTOM_FWD = identity_decorator
    _TORCH_CUSTOM_BWD = identity_decorator
elif PYTORCH_VERSION < [2, 5, 0]:
    # CUDA-specific AMP API.
    import torch.cuda.amp as amp
    _TORCH_CUSTOM_FWD = amp.custom_fwd(cast_inputs=torch.float16)
    _TORCH_CUSTOM_BWD = amp.custom_bwd
else:
    # Device-agnostic AMP API; the device type is passed explicitly.
    import torch.amp as amp
    _TORCH_CUSTOM_FWD = amp.custom_fwd(cast_inputs=torch.float16,
                                       device_type="cuda")
    _TORCH_CUSTOM_BWD = amp.custom_bwd(device_type="cuda")
traveller59's avatar
traveller59 committed
57

yan.yan's avatar
add act  
yan.yan committed
58

traveller59's avatar
traveller59 committed
59
60
class SparseConvFunction(Function):
    """Autograd function for a regular sparse convolution driven by
    precomputed indice pairs (``ops.indice_conv`` / ``ops.indice_conv_backward``).
    """

    @staticmethod
    @_TORCH_CUSTOM_FWD
    def forward(ctx,
                features,
                filters,
                indice_pairs,
                indice_pair_num,
                num_activate_out,
                algo,
                timer: CUDAKernelTimer = CUDAKernelTimer(False),
                bias: Optional[torch.Tensor] = None,
                act_alpha: float = 0.0,
                act_beta: float = 0.0,
                act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
        """Run the forward convolution, forwarding ``bias`` and the activation
        parameters to ``ops.indice_conv``.

        On failure, a shape summary is printed to stderr and the pair tensors
        are dumped with ``spconv_save_debug_data`` before the exception
        propagates.
        """
        # Tensors go through save_for_backward; algo/timer are plain objects.
        ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
        ctx.algo = algo
        ctx.timer = timer
        try:
            # The two positional booleans are presumably (inverse, subm),
            # judging by the Inverse/SubM sibling functions; confirm against
            # ops.indice_conv.
            return ops.indice_conv(features, filters, indice_pairs,
                                   indice_pair_num, num_activate_out, False,
                                   algo=algo, timer=timer, bias=bias,
                                   act_alpha=act_alpha, act_beta=act_beta,
                                   act_type=act_type)
        except Exception:
            summary = ("[Exception|indice_conv]"
                       f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
                       f"pairnum={indice_pair_num},act={num_activate_out},algo={algo}")
            print(summary, file=sys.stderr)
            spconv_save_debug_data((indice_pairs, indice_pair_num))
            raise

    @staticmethod
    @once_differentiable
    @_TORCH_CUSTOM_BWD
    def backward(ctx, grad_output):
        """Return gradients for ``features`` and ``filters`` only.

        The other nine forward inputs get ``None``. NOTE(review): this
        includes ``bias``, which never receives a gradient here.
        """
        indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
        algo, timer = ctx.algo, ctx.timer
        try:
            input_bp, filters_bp = ops.indice_conv_backward(
                features, filters, grad_output, indice_pairs, indice_pair_num,
                False, algo=algo, timer=timer)
        except Exception:
            summary = ("[Exception|indice_conv_backward]"
                       f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
                       f"pairnum={indice_pair_num},do={grad_output.shape}")
            print(summary, file=sys.stderr)
            spconv_save_debug_data((indice_pairs, indice_pair_num))
            raise
        # forward takes 11 inputs after ctx: 2 gradients + 9 Nones.
        return (input_bp, filters_bp) + (None,) * 9
traveller59's avatar
traveller59 committed
122

123

traveller59's avatar
traveller59 committed
124
125
class SparseInverseConvFunction(Function):
    """Autograd function for an inverse sparse convolution driven by
    precomputed indice pairs (``ops.indice_conv`` / ``ops.indice_conv_backward``
    with the first flag set).
    """

    @staticmethod
    @_TORCH_CUSTOM_FWD
    def forward(ctx,
                features,
                filters,
                indice_pairs,
                indice_pair_num,
                num_activate_out,
                algo,
                timer: CUDAKernelTimer = CUDAKernelTimer(False),
                bias: Optional[torch.Tensor] = None,
                act_alpha: float = 0.0,
                act_beta: float = 0.0,
                act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
        """Run the inverse convolution, forwarding ``bias`` and the activation
        parameters to ``ops.indice_conv``.

        On failure, a shape summary is printed to stderr and the pair tensors
        are dumped with ``spconv_save_debug_data`` before the exception
        propagates.
        """
        # Tensors go through save_for_backward; algo/timer are plain objects.
        ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
        ctx.algo = algo
        ctx.timer = timer
        try:
            # Positional flags presumably (inverse=True, subm=False); confirm
            # against ops.indice_conv.
            return ops.indice_conv(features, filters, indice_pairs,
                                   indice_pair_num, num_activate_out, True,
                                   False, algo=algo, timer=timer, bias=bias,
                                   act_alpha=act_alpha, act_beta=act_beta,
                                   act_type=act_type)
        except Exception:
            summary = ("[Exception|indice_conv|inverse]"
                       f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
                       f"pairnum={indice_pair_num},act={num_activate_out},algo={algo}")
            print(summary, file=sys.stderr)
            spconv_save_debug_data((indice_pairs, indice_pair_num))
            raise

    @staticmethod
    @once_differentiable
    @_TORCH_CUSTOM_BWD
    def backward(ctx, grad_output):
        """Return gradients for ``features`` and ``filters`` only.

        The other nine forward inputs get ``None``. NOTE(review): this
        includes ``bias``, which never receives a gradient here.
        """
        indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
        algo, timer = ctx.algo, ctx.timer
        try:
            input_bp, filters_bp = ops.indice_conv_backward(
                features, filters, grad_output, indice_pairs, indice_pair_num,
                True, False, algo=algo, timer=timer)
        except Exception:
            summary = ("[Exception|indice_conv_backward|inverse]"
                       f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
                       f"pairnum={indice_pair_num},do={grad_output.shape}")
            print(summary, file=sys.stderr)
            spconv_save_debug_data((indice_pairs, indice_pair_num))
            raise
        # forward takes 11 inputs after ctx: 2 gradients + 9 Nones.
        return (input_bp, filters_bp) + (None,) * 9
traveller59's avatar
traveller59 committed
189
190


yan.yan's avatar
v2.1  
yan.yan committed
191
192
class SparseImplicitGemmFunction(Function):
    """Autograd function for sparse convolution via ``ops.implicit_gemm``.

    Single tensors needed by backward go through ``save_for_backward``. The
    mask split lists, ``masks`` (a list of NumPy arrays) and the scalar options
    are not single tensors, so they are kept as plain ``ctx`` attributes.
    """

    @staticmethod
    @_TORCH_CUSTOM_FWD
    def forward(ctx,
                features: torch.Tensor,
                filters: torch.Tensor,
                pair_fwd: torch.Tensor,
                pair_bwd: torch.Tensor,
                pair_mask_fwd_splits: List[torch.Tensor],
                pair_mask_bwd_splits: List[torch.Tensor],
                mask_argsort_fwd_splits: List[torch.Tensor],
                mask_argsort_bwd_splits: List[torch.Tensor],
                num_activate_out: int,
                masks: List[np.ndarray],
                is_train: bool,
                is_subm: bool,
                timer: CUDAKernelTimer = CUDAKernelTimer(False),
                fp32_accum: Optional[bool] = None,
                bias: Optional[torch.Tensor] = None,
                act_alpha: float = 0.0,
                act_beta: float = 0.0,
                act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
        """Run the implicit-GEMM forward pass and return its output features.

        ``ops.implicit_gemm`` also returns an output mask and a mask width;
        both are stashed on ``ctx`` for backward. On failure, a shape summary
        is printed to stderr and the pair/mask data is dumped with
        ``spconv_save_debug_data`` before the exception propagates.
        """
        try:
            out, mask_out, mask_width = ops.implicit_gemm(
                features, filters, pair_fwd, pair_mask_fwd_splits,
                mask_argsort_fwd_splits, num_activate_out, masks, is_train,
                is_subm, timer, fp32_accum, bias, act_alpha, act_beta,
                act_type)
        except Exception:
            summary = ("[Exception|implicit_gemm]"
                       f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape},"
                       f"act={num_activate_out},issubm={is_subm},istrain={is_train}")
            print(summary, file=sys.stderr)
            spconv_save_debug_data(
                (pair_fwd, pair_bwd, pair_mask_fwd_splits,
                 pair_mask_bwd_splits, mask_argsort_fwd_splits,
                 mask_argsort_bwd_splits, masks))
            raise

        ctx.save_for_backward(features, filters, pair_fwd, pair_bwd)
        # Extra outputs of the forward kernel consumed by backward.
        ctx.mask_out = mask_out
        ctx.mask_width = mask_width
        # Non-tensor inputs re-used by backward.
        ctx.pair_mask_fwd_splits = pair_mask_fwd_splits
        ctx.pair_mask_bwd_splits = pair_mask_bwd_splits
        ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits
        ctx.mask_argsort_bwd_splits = mask_argsort_bwd_splits
        ctx.masks = masks
        ctx.is_subm = is_subm
        ctx.fp32_accum = fp32_accum
        ctx.timer = timer
        return out

    @staticmethod
    @once_differentiable
    @_TORCH_CUSTOM_BWD
    def backward(ctx, grad_output):
        """Return gradients for ``features`` and ``filters``; the 16 remaining
        forward inputs (including ``bias``) get ``None``."""
        features, filters, pair_fwd, pair_bwd = ctx.saved_tensors
        fwd_mask_splits = ctx.pair_mask_fwd_splits
        bwd_mask_splits = ctx.pair_mask_bwd_splits
        fwd_argsort_splits = ctx.mask_argsort_fwd_splits
        bwd_argsort_splits = ctx.mask_argsort_bwd_splits
        masks = ctx.masks
        is_subm = ctx.is_subm
        mask_out = ctx.mask_out
        mask_width = ctx.mask_width
        timer = ctx.timer
        fp32_accum = ctx.fp32_accum
        try:
            input_bp, filters_bp = ops.implicit_gemm_backward(
                features, filters, grad_output, pair_fwd, pair_bwd,
                fwd_mask_splits, bwd_mask_splits,
                fwd_argsort_splits, bwd_argsort_splits,
                mask_output_fwd=mask_out,
                masks=masks,
                mask_width=mask_width,
                is_subm=is_subm,
                timer=timer,
                fp32_accum=fp32_accum)
        except Exception:
            summary = ("[Exception|implicit_gemm_backward]"
                       f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape},"
                       f"issubm={is_subm},do={grad_output.shape}")
            print(summary, file=sys.stderr)
            spconv_save_debug_data(
                (pair_fwd, pair_bwd, fwd_mask_splits, bwd_mask_splits,
                 fwd_argsort_splits, bwd_argsort_splits, masks))
            raise
        # forward takes 18 inputs after ctx: 2 gradients + 16 Nones.
        return (input_bp, filters_bp) + (None,) * 16
yan.yan's avatar
v2.1  
yan.yan committed
291

292

traveller59's avatar
traveller59 committed
293
294
class SubMConvFunction(Function):
    """Autograd function for a submanifold sparse convolution driven by
    precomputed indice pairs (``ops.indice_conv`` / ``ops.indice_conv_backward``
    with the second flag set).
    """

    @staticmethod
    @_TORCH_CUSTOM_FWD
    def forward(ctx,
                features,
                filters,
                indice_pairs,
                indice_pair_num,
                num_activate_out,
                algo,
                timer: CUDAKernelTimer = CUDAKernelTimer(False),
                bias: Optional[torch.Tensor] = None,
                act_alpha: float = 0.0,
                act_beta: float = 0.0,
                act_type: tv.gemm.Activation = tv.gemm.Activation.None_):
        """Run the submanifold convolution, forwarding ``bias`` and the
        activation parameters to ``ops.indice_conv``.

        On failure, a shape summary is printed to stderr and the pair tensors
        are dumped with ``spconv_save_debug_data`` before the exception
        propagates.
        """
        # Tensors go through save_for_backward; algo/timer are plain objects.
        ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
        ctx.algo = algo
        ctx.timer = timer
        try:
            # Positional flags presumably (inverse=False, subm=True); confirm
            # against ops.indice_conv.
            return ops.indice_conv(features, filters, indice_pairs,
                                   indice_pair_num, num_activate_out, False,
                                   True, algo=algo, timer=timer, bias=bias,
                                   act_alpha=act_alpha, act_beta=act_beta,
                                   act_type=act_type)
        except Exception:
            summary = ("[Exception|indice_conv|subm]"
                       f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
                       f"pairnum={indice_pair_num},act={num_activate_out},algo={algo}")
            print(summary, file=sys.stderr)
            spconv_save_debug_data((indice_pairs, indice_pair_num))
            raise

    @staticmethod
    @once_differentiable
    @_TORCH_CUSTOM_BWD
    def backward(ctx, grad_output):
        """Return gradients for ``features`` and ``filters`` only.

        The other nine forward inputs get ``None``. NOTE(review): this
        includes ``bias``, which never receives a gradient here.
        """
        indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
        algo, timer = ctx.algo, ctx.timer
        try:
            input_bp, filters_bp = ops.indice_conv_backward(
                features, filters, grad_output, indice_pairs, indice_pair_num,
                False, True, algo=algo, timer=timer)
        except Exception:
            summary = ("[Exception|indice_conv_backward|subm]"
                       f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
                       f"pairnum={indice_pair_num},do={grad_output.shape}")
            print(summary, file=sys.stderr)
            spconv_save_debug_data((indice_pairs, indice_pair_num))
            raise
        # forward takes 11 inputs after ctx: 2 gradients + 9 Nones.
        return (input_bp, filters_bp) + (None,) * 9
traveller59's avatar
traveller59 committed
358
359
360
361


class SparseMaxPoolFunction(Function):
    """Autograd function for sparse max pooling over precomputed indice pairs."""

    @staticmethod
    @_TORCH_CUSTOM_FWD
    def forward(ctx, features, indice_pairs, indice_pair_num,
                num_activate_out):
        """Pool ``features`` with ``ops.indice_maxpool`` and return the result."""
        pooled = ops.indice_maxpool(features, indice_pairs, indice_pair_num,
                                    num_activate_out)
        # The pooled output is saved too: backward hands it to
        # ops.indice_maxpool_backward alongside the inputs.
        ctx.save_for_backward(indice_pairs, indice_pair_num, features, pooled)
        return pooled

    @staticmethod
    @once_differentiable
    @_TORCH_CUSTOM_BWD
    def backward(ctx, grad_output):
        """Return the gradient for ``features``; the pair tensors and the
        output count get ``None``."""
        indice_pairs, indice_pair_num, features, pooled = ctx.saved_tensors
        grad_features = ops.indice_maxpool_backward(features, pooled,
                                                    grad_output, indice_pairs,
                                                    indice_pair_num)
        return grad_features, None, None, None

379

yan.yan's avatar
v2.1  
yan.yan committed
380
381
class SparseMaxPoolImplicitGemmFunction(Function):
    """Autograd function for sparse max pooling with implicit-GEMM pairs.

    Forward consumes ``indice_pairs_fwd``; only ``indice_pairs_bwd`` is kept
    for backward.
    """

    @staticmethod
    @_TORCH_CUSTOM_FWD
    def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor,
                indice_pairs_bwd: torch.Tensor, num_activate_out: int):
        """Pool ``features`` with ``ops.indice_maxpool_implicit_gemm``."""
        pooled = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd,
                                                  num_activate_out)
        # Backward needs the inputs and the pooled output.
        ctx.save_for_backward(indice_pairs_bwd, features, pooled)
        return pooled

    @staticmethod
    @once_differentiable
    @_TORCH_CUSTOM_BWD
    def backward(ctx, grad_output):
        """Return the gradient for ``features``; the other three inputs get
        ``None``."""
        pairs_bwd, features, pooled = ctx.saved_tensors
        grad_features = ops.indice_maxpool_implicit_gemm_backward(
            features, pooled, grad_output, pairs_bwd)
        return grad_features, None, None, None
398

yan.yan's avatar
add act  
yan.yan committed
399

Yan Yan's avatar
Yan Yan committed
400
401
402
403
class SparseAvgPoolImplicitGemmFunction(Function):
    """Autograd function for sparse average pooling with implicit-GEMM pairs.

    Forward consumes ``indice_pairs_fwd``. Backward only needs
    ``indice_pairs_bwd`` and the per-output ``count`` returned by the forward
    kernel.
    """

    @staticmethod
    @_TORCH_CUSTOM_FWD
    def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor,
                indice_pairs_bwd: torch.Tensor, num_activate_out: int,
                calc_count):
        """Average-pool ``features`` with ``ops.indice_avgpool_implicit_gemm``.

        Args:
            features: input features.
            indice_pairs_fwd: pairs used by the forward kernel.
            indice_pairs_bwd: pairs used by the backward kernel.
            num_activate_out: number of output rows.
            calc_count: forwarded to the kernel as-is.

        Returns:
            The pooled features.
        """
        out, count = ops.indice_avgpool_implicit_gemm(features,
                                                      indice_pairs_fwd,
                                                      num_activate_out,
                                                      calc_count)
        # Save only what backward reads. ``features`` and ``out`` used to be
        # saved as well, which kept them alive until backward ran (raising
        # peak memory) even though backward never uses them.
        ctx.save_for_backward(indice_pairs_bwd, count)
        return out

    @staticmethod
    @once_differentiable
    @_TORCH_CUSTOM_BWD
    def backward(ctx, grad_output):
        """Return the gradient for ``features``; the other four inputs get
        ``None``."""
        indice_pairs_bwd, count = ctx.saved_tensors
        input_bp = ops.indice_avgpool_implicit_gemm_backward(
            grad_output, indice_pairs_bwd, count)
        return input_bp, None, None, None, None

422

traveller59's avatar
traveller59 committed
423
# Functional entry points: each name is the ``apply`` of the matching custom
# autograd Function above.

# Indice-pair based convolutions.
indice_conv = SparseConvFunction.apply
indice_inverse_conv = SparseInverseConvFunction.apply
indice_subm_conv = SubMConvFunction.apply

# Implicit-GEMM convolution.
implicit_gemm = SparseImplicitGemmFunction.apply

# Pooling.
indice_maxpool = SparseMaxPoolFunction.apply
indice_maxpool_implicit_gemm = SparseMaxPoolImplicitGemmFunction.apply
indice_avgpool_implicit_gemm = SparseAvgPoolImplicitGemmFunction.apply
yan.yan's avatar
yan.yan committed
430
431


432
433
434
def _indice_to_scalar(indices: torch.Tensor, shape: List[int]):
    """Flatten per-row integer coordinates into one scalar key per row.

    Args:
        indices: 2-D tensor with one coordinate per column; the column count
            must equal ``len(shape)``.
        shape: extent of each coordinate axis.

    Returns:
        A contiguous 1-D tensor with the same dtype as ``indices``. Entry ``r``
        is ``indices[r, -1] + sum_i stride[i] * indices[r, i]`` over all axes
        except the last, where ``stride = to_stride(shape)``. This assumes
        ``to_stride`` yields row-major (C-order) strides; confirm in
        ``cumm.gemm.layout``.
    """
    ndim = len(shape)
    assert indices.shape[1] == ndim
    strides = to_stride(np.array(shape, dtype=np.int64))
    # Start from a copy of the innermost axis, then add the outer axes
    # in place (so the result keeps the dtype of ``indices``).
    flat = indices[:, ndim - 1].clone()
    for axis, axis_stride in enumerate(strides[:ndim - 1]):
        flat += axis_stride * indices[:, axis]
    return flat.contiguous()

yan.yan's avatar
add act  
yan.yan committed
440

441
def sparse_add_hash_based(*tens: SparseConvTensor):
    """Sum sparse tensors whose indices may be misaligned, using a hash table.

    Every distinct (batch, *spatial) coordinate found in any operand becomes
    one output row. That row holds the sum of the features of all operands
    that contain the coordinate. All operands must share ``spatial_shape``,
    ``batch_size`` and the feature (channel) count.

    The result normally does not inherit any ``indice_dict``, so cached
    indice data (e.g. for an inverse convolution) cannot be reused after the
    add. The one exception: the union of all coordinates may be exactly the
    coordinate set of the largest operand. Then that operand's
    ``indice_dict`` is reused, and the output rows are laid out in the same
    order as that operand's rows.

    Args:
        *tens: operands to add. ``tens[0]`` supplies the shape, device, dtypes
            and benchmark/timer/allocator settings of the result.

    Returns:
        A new SparseConvTensor holding the sum.
    """
    first = tens[0]
    feat = first.features
    table_size = 0
    max_num_indices = 0
    max_num_indices_idx = 0
    for i, ten in enumerate(tens):
        assert ten.spatial_shape == first.spatial_shape
        assert ten.batch_size == first.batch_size
        assert ten.features.shape[1] == feat.shape[1]
        num_rows = ten.features.shape[0]
        table_size += num_rows
        if max_num_indices < num_rows:
            max_num_indices_idx = i
            max_num_indices = num_rows

    shape = [first.batch_size, *first.spatial_shape]
    # Accumulate in int64 explicitly. Otherwise np.prod uses the platform's
    # default integer, which is 32-bit on some platforms (e.g. Windows with
    # NumPy < 2). It would overflow silently for exactly the large grids the
    # int32/int64 key decision below exists for.
    whole_shape = int(np.prod(shape, dtype=np.int64))
    use_int64_keys = whole_shape >= _MAX_INT32
    k_type = torch.int64 if use_int64_keys else torch.int32
    # Capacity: twice the total number of input rows (presumably to keep the
    # table's load factor low).
    table = HashTable(feat.device, k_type, torch.int32, table_size * 2)
    scalars: List[torch.Tensor] = []
    for ten in tens:
        indices = ten.indices.long() if use_int64_keys else ten.indices
        scalar = _indice_to_scalar(indices, shape)
        scalars.append(scalar)
        table.insert(scalar)
    # Assign consecutive ids to the table values; the returned count is the
    # number of distinct coordinates, i.e. the number of output rows.
    count_val = int(table.assign_arange_().item())

    out_features = torch.zeros([count_val, feat.shape[1]],
                               dtype=feat.dtype,
                               device=feat.device)
    out_indices = torch.zeros([count_val, first.indices.shape[1]],
                              dtype=first.indices.dtype,
                              device=first.indices.device)
    row_ids: List[torch.Tensor] = []
    for ten, scalar in zip(tens, scalars):
        out_inds, _ = table.query(scalar)
        out_inds = out_inds.long()
        row_ids.append(out_inds)
        # Assumes coordinates are unique within one operand: indexed ``+=``
        # does not accumulate duplicate targets. Overlaps across operands
        # add up because each operand is accumulated in its own pass.
        out_features[out_inds] += ten.features
        out_indices[out_inds] = ten.indices

    reuse_indice_dict = count_val == max_num_indices
    if reuse_indice_dict:
        # Output rows follow the ids handed out by assign_arange_, and nothing
        # here ties that numbering to the row order of the largest operand.
        # Its indice_dict presumably addresses feature rows by position, so
        # permute the output to match that operand row-for-row. Its
        # coordinates cover the whole output, so its ids form a permutation.
        perm = row_ids[max_num_indices_idx]
        out_features = out_features[perm]
        out_indices = out_indices[perm]

    res = SparseConvTensor(out_features,
                           out_indices,
                           first.spatial_shape,
                           first.batch_size,
                           benchmark=first.benchmark)
    if reuse_indice_dict:
        res.indice_dict = tens[max_num_indices_idx].indice_dict
    res.benchmark_record = first.benchmark_record
    res._timer = first._timer
    res.thrust_allocator = first.thrust_allocator
    return res

504

505
506
507
508
509
510
511
def sparse_add(*tens: SparseConvTensor):
    """Sum sparse tensors whose indices may be misaligned, via ``torch.sparse``.

    Each operand is wrapped as a hybrid COO tensor (sparse batch/spatial dims,
    dense channel dim). The operands are added and coalesced, which merges
    duplicate coordinates (internally a sort + unique). All operands must share
    ``spatial_shape``, ``batch_size`` and the feature (channel) count.

    The union of coordinates may be exactly the coordinate set of the largest
    operand. Then that operand's ``indice_dict`` is reused, and the output
    rows are reordered to match that operand's rows.

    Args:
        *tens: operands to add. ``tens[0]`` supplies the shape and the
            benchmark/timer/allocator settings of the result.

    Returns:
        A new SparseConvTensor holding the sum, with int32 indices.
    """
    first = tens[0]
    res_shape = [
        first.batch_size, *first.spatial_shape, first.features.shape[1]
    ]
    max_num_indices = 0
    max_num_indices_idx = 0
    ten_ths: List[torch.Tensor] = []
    for i, ten in enumerate(tens):
        assert ten.spatial_shape == first.spatial_shape
        assert ten.batch_size == first.batch_size
        assert ten.features.shape[1] == first.features.shape[1]
        if max_num_indices < ten.features.shape[0]:
            max_num_indices_idx = i
            max_num_indices = ten.features.shape[0]
        # Indices are transposed into the (sparse_dim, nnz) layout that
        # sparse_coo_tensor expects.
        ten_ths.append(
            torch.sparse_coo_tensor(ten.indices.T,
                                    ten.features,
                                    res_shape,
                                    requires_grad=True))

    c_th = reduce(lambda x, y: x + y, ten_ths).coalesce()
    c_th_inds = c_th.indices().T.contiguous().int()
    c_th_values = c_th.values()
    assert c_th_values.is_contiguous()

    reuse_indice_dict = c_th_values.shape[0] == max_num_indices
    if reuse_indice_dict:
        # coalesce() sorts the rows, but the largest operand's rows may be in
        # any order. Its indice_dict presumably addresses feature rows by
        # position, so reorder the output to match it row-for-row. Both sides
        # hold the same coordinate set here, so pairing the k-th smallest
        # flattened key of each side pairs identical coordinates.
        grid_shape = res_shape[:-1]
        base_indices = tens[max_num_indices_idx].indices
        out_keys = _indice_to_scalar(c_th_inds.long(), grid_shape)
        base_keys = _indice_to_scalar(base_indices.long(), grid_shape)
        out_order = torch.argsort(out_keys)
        base_order = torch.argsort(base_keys)
        # perm[j] is the output row holding base row j's coordinate.
        perm = torch.empty_like(base_order)
        perm[base_order] = out_order
        c_th_inds = c_th_inds[perm]
        c_th_values = c_th_values[perm]

    res = SparseConvTensor(c_th_values,
                           c_th_inds,
                           first.spatial_shape,
                           first.batch_size,
                           benchmark=first.benchmark)
    if reuse_indice_dict:
        res.indice_dict = tens[max_num_indices_idx].indice_dict
    res.benchmark_record = first.benchmark_record
    res._timer = first._timer
    res.thrust_allocator = first.thrust_allocator
    return res