# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import time
import sys
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter

from spconv import pytorch as spconv
from spconv import SPCONV_VERSION_NUMBERS
from spconv.core import ConvAlgo
from spconv.debug_utils import spconv_save_debug_data
from spconv.pytorch import functional as Fsp
from spconv.pytorch import ops
from spconv.cppconstants import CPU_ONLY_BUILD
from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData, expand_nd
from spconv.pytorch.modules import SparseModule
from spconv.constants import SAVED_WEIGHT_LAYOUT, ALL_WEIGHT_IS_KRSC
from spconv.utils import nullcontext
from torch.nn.init import calculate_gain

# Selects the weight layout of native-algorithm convolutions when the global
# KRSC layout (ALL_WEIGHT_IS_KRSC) is not in effect:
#   True  -> RSCK: (*kernel_size, in_channels, out_channels)
#   False -> RSKC: (*kernel_size, out_channels, in_channels)
FILTER_HWIO = False


def expand_nd(ndim: int, val: Union[int, List[int], Tuple[int, ...]]) -> List[int]:
    """Broadcast ``val`` to a list of exactly ``ndim`` ints.

    The ``(ndim, val)`` argument order matches how ``SparseConvolution``
    calls this helper (``expand_nd(ndim, kernel_size)`` etc.). Note that this
    definition shadows the ``expand_nd`` imported from ``spconv.pytorch.core``.

    Args:
        ndim: number of spatial dimensions.
        val: a scalar int (repeated ``ndim`` times), or a list/tuple that
            must already contain ``ndim`` items.

    Returns:
        A new list of length ``ndim``; the input sequence is never aliased.

    Raises:
        AssertionError: if a list/tuple does not have exactly ``ndim`` items.
        NotImplementedError: if ``val`` is of any other type.
    """
    if isinstance(val, int):
        return [val] * ndim
    if isinstance(val, (list, tuple)):
        assert len(val) == ndim, f"expected {ndim} values, got {len(val)}"
        # always return a fresh list so later in-place edits can't leak back
        return list(val)
    raise NotImplementedError


class SparseConvolution(SparseModule):
    """Generic N-d sparse convolution over a ``SparseConvTensor``.

    One module covers regular, submanifold (``subm``), transposed and inverse
    convolution. Index pairs are generated on demand and can be cached in /
    reused from the input tensor's ``indice_dict`` through ``indice_key``.
    """
    __constants__ = [
        'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse',
        'transposed', 'output_padding'
    ]

    def __init__(self,
                 ndim: int,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
                 stride: Union[int, List[int], Tuple[int, ...]] = 1,
                 padding: Union[int, List[int], Tuple[int, ...]] = 0,
                 dilation: Union[int, List[int], Tuple[int, ...]] = 1,
                 groups: Union[int, List[int], Tuple[int, ...]] = 1,
                 bias: bool = True,
                 subm: bool = False,
                 output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
                 transposed: bool = False,
                 inverse: bool = False,
                 indice_key: Optional[str] = None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super(SparseConvolution, self).__init__(name=name)
        assert groups == 1, "don't support groups for now"
        self.ndim = ndim
        self.in_channels = in_channels
        self.out_channels = out_channels
        # keyword arguments keep these calls correct regardless of which
        # parameter order the resolved ``expand_nd`` uses.
        self.kernel_size = expand_nd(ndim=ndim, val=kernel_size)
        self.stride = expand_nd(ndim=ndim, val=stride)
        kv = int(np.prod(self.kernel_size))
        kv_stride = int(np.prod(self.stride))
        self.dilation = expand_nd(ndim=ndim, val=dilation)
        self.padding = expand_nd(ndim=ndim, val=padding)

        self.conv1x1 = kv == 1
        # TODO we should deprecate support for ksize == 1 but stride != 1.
        if not subm:
            self.conv1x1 &= kv_stride == 1
            if self.conv1x1:
                assert all(p == 0 for p in self.padding), \
                    "padding must be zero for 1x1 conv (k=1,s=1)"
        self.transposed = transposed
        self.inverse = inverse
        self.output_padding = expand_nd(ndim=ndim, val=output_padding)
        self.groups = groups
        self.subm = subm
        self.indice_key = indice_key
        if algo is None:
            # implicit gemm handles at most 32 kernel offsets and is not
            # available in CPU-only builds (see the asserts below).
            if kv <= 32 and not CPU_ONLY_BUILD:
                algo = ConvAlgo.MaskImplicitGemm
            else:
                algo = ConvAlgo.Native
        if kv > 32:
            assert algo == ConvAlgo.Native, "implicit gemm doesn't support kv > 32 for now"
        if CPU_ONLY_BUILD:
            assert algo == ConvAlgo.Native, "cpu only build only support native algorithm"
        self.algo = algo
        self.fp32_accum = fp32_accum

        layout = self._weight_layout()
        if layout == "RSCK":
            self.weight = Parameter(
                torch.Tensor(*self.kernel_size, in_channels, out_channels))
        elif layout == "RSKC":
            self.weight = Parameter(
                torch.Tensor(*self.kernel_size, out_channels, in_channels))
        else:  # KRSC
            self.weight = Parameter(
                torch.Tensor(out_channels, *self.kernel_size, in_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self._register_load_state_dict_pre_hook(self._load_weight_different_layout)

    def _weight_layout(self) -> str:
        """Return the layout of ``self.weight``: "KRSC", "RSKC" or "RSCK".

        K = out_channels, C = in_channels, "RS" = the ``ndim`` kernel dims.
        Only the native algorithm without the global KRSC switch uses a
        non-KRSC layout, chosen by ``FILTER_HWIO``.
        """
        if self.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
            return "RSCK" if FILTER_HWIO else "RSKC"
        return "KRSC"

    @staticmethod
    def _layout_axes(layout: str, ndim: int) -> List[str]:
        """Expand a layout name into one label per tensor axis.

        e.g. ("KRSC", 3) -> ["K", "S0", "S1", "S2", "C"].
        """
        axes: List[str] = []
        for ch in layout.replace("RS", "S"):
            if ch == "S":
                axes.extend(f"S{i}" for i in range(ndim))
            else:
                axes.append(ch)
        return axes

    def _load_weight_different_layout(
            self, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
        """``load_state_dict`` pre-hook converting saved weights to our layout.

        When ``SAVED_WEIGHT_LAYOUT`` names the layout a checkpoint was saved
        in, the incoming ``<prefix>weight`` is permuted (exactly once) into
        the layout this module allocated (see ``_weight_layout``). Unknown
        layout names and already-matching layouts are left untouched.
        """
        if not SAVED_WEIGHT_LAYOUT:
            return
        key = prefix + "weight"
        if key not in state_dict:
            # let load_state_dict report (or, with strict=False, tolerate)
            # the missing key instead of failing here.
            return
        src = SAVED_WEIGHT_LAYOUT
        dst = self._weight_layout()
        if src == dst or src not in ("KRSC", "RSKC", "RSCK"):
            return
        src_axes = self._layout_axes(src, self.ndim)
        dst_axes = self._layout_axes(dst, self.ndim)
        perm = [src_axes.index(axis) for axis in dst_axes]
        state_dict[key] = state_dict[key].permute(*perm).contiguous()

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        # padding/dilation/output_padding are lists; compare element-wise
        # (a list never equals a tuple, which made these always print).
        if any(p != 0 for p in self.padding):
            s += ', padding={padding}'
        if any(d != 1 for d in self.dilation):
            s += ', dilation={dilation}'
        if any(p != 0 for p in self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.algo is not None:
            s += ', algo={algo}'
        return s.format(**self.__dict__)

    def _calculate_fan_in_and_fan_out(self):
        receptive_field_size = 1
        # math.prod is not always available, accumulate the product manually
        # we could use functools.reduce but that is not supported by TorchScript
        for s in self.kernel_size:
            receptive_field_size *= s
        fan_in = self.in_channels * receptive_field_size
        fan_out = self.out_channels * receptive_field_size
        return fan_in, fan_out

    def _calculate_correct_fan(self, mode):
        mode = mode.lower()
        valid_modes = ['fan_in', 'fan_out']
        if mode not in valid_modes:
            raise ValueError(
                "Mode {} not supported, please use one of {}".format(
                    mode, valid_modes))

        fan_in, fan_out = self._calculate_fan_in_and_fan_out()
        return fan_in if mode == 'fan_in' else fan_out

    def _custom_kaiming_uniform_(self,
                                 tensor,
                                 a=0,
                                 mode='fan_in',
                                 nonlinearity='leaky_relu'):
        r"""same as torch.init.kaiming_uniform_, with KRSC layout support

        Fans are computed from kernel_size and channel counts, so the result
        does not depend on the weight's memory layout.
        """
        fan = self._calculate_correct_fan(mode)
        gain = calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        # Calculate uniform bounds from standard deviation
        bound = math.sqrt(3.0) * std
        with torch.no_grad():
            return tensor.uniform_(-bound, bound)

    def reset_parameters(self):
        self._custom_kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = self._calculate_fan_in_and_fan_out()
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: SparseConvTensor):
        assert isinstance(input, SparseConvTensor)
        assert input.features.shape[
            1] == self.in_channels, "channel size mismatch"
        features = input.features
        indices = input.indices
        spatial_shape = input.spatial_shape
        batch_size = input.batch_size
        if not self.subm:
            if self.transposed:
                out_spatial_shape = ops.get_deconv_output_size(
                    spatial_shape, self.kernel_size, self.stride, self.padding,
                    self.dilation, self.output_padding)
            else:
                out_spatial_shape = ops.get_conv_output_size(
                    spatial_shape, self.kernel_size, self.stride, self.padding,
                    self.dilation)
        else:
            out_spatial_shape = spatial_shape
        out_tensor = input.shadow_copy()
        if input.benchmark:
            if self.name is None:
                raise ValueError(
                    "you need to assign name to spmodules before benchmark (spconv.utils.bench.assign_name_to_spmod)"
                )
            if self.name not in input.benchmark_record:
                input.benchmark_record[self.name] = {
                    "type": "SparseConvolution",
                    "indice_gen_time": [],
                    "time": [],
                    "num_points": [],
                    "num_out_points": [],
                    "params": {
                        "kernel_size": self.kernel_size,
                        "stride": self.stride,
                        "padding": self.padding,
                        "dilation": self.dilation,
                        "output_padding": self.output_padding,
                        "subm": self.subm,
                        "transposed": self.transposed,
                        "input_channels": self.in_channels,
                        "out_channels": self.out_channels,
                    }
                }
        if self.conv1x1:
            if self._weight_layout() == "RSCK":
                # (*1s, in_channels, out_channels): memory is already (C, K)
                weight_2d = self.weight.view(self.in_channels,
                                             self.out_channels)
            else:
                # RSKC / KRSC with a unit kernel: memory is (K, C)
                weight_2d = self.weight.view(self.out_channels,
                                             self.in_channels).T
            features = torch.mm(input.features, weight_2d)
            if self.bias is not None:
                features += self.bias
            out_tensor = out_tensor.replace_feature(features)
            # padding may change spatial shape of conv 1x1.
            out_tensor.spatial_shape = out_spatial_shape
            return out_tensor
        indice_dict = input.indice_dict.copy()

        algo = self.algo
        if self.indice_key is not None:
            datas = input.find_indice_pair(self.indice_key)
            if datas is not None:
                msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key."
                assert algo == datas.algo, msg
        profile_ctx = nullcontext()
        if input._timer is not None and self._sparse_unique_name:
            profile_ctx = input._timer.namespace(self._sparse_unique_name)
        with profile_ctx:
            if algo == ConvAlgo.Native:
                datas = input.find_indice_pair(self.indice_key)
                if datas is not None:
                    assert isinstance(datas, IndiceData)
                if self.inverse:
                    assert datas is not None and self.indice_key is not None
                    assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
                    outids = datas.indices
                    indice_pairs = datas.indice_pairs
                    indice_pair_num = datas.indice_pair_num
                    out_spatial_shape = datas.spatial_shape
                    assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
                else:
                    if self.indice_key is not None and datas is not None:
                        outids = datas.out_indices
                        indice_pairs = datas.indice_pairs
                        indice_pair_num = datas.indice_pair_num
                        assert self.subm, "only support reuse subm indices"
                        self._check_subm_reuse_valid(input, spatial_shape,
                                                     datas)
                    else:
                        if input.benchmark:
                            torch.cuda.synchronize()
                            t = time.time()
                        try:
                            outids, indice_pairs, indice_pair_num = ops.get_indice_pairs(
                                indices, batch_size, spatial_shape, algo,
                                self.kernel_size, self.stride, self.padding,
                                self.dilation, self.output_padding, self.subm,
                                self.transposed)
                        except Exception:
                            msg = "[Exception|native_pair]"
                            msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
                            msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
                            msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
                            msg += f"transpose={self.transposed}"
                            print(msg, file=sys.stderr)
                            spconv_save_debug_data(indices)
                            raise
                        if input.benchmark:
                            torch.cuda.synchronize()
                            interval = time.time() - t
                            out_tensor.benchmark_record[
                                self.name]["indice_gen_time"].append(interval)

                        indice_data = IndiceData(outids,
                                                 indices,
                                                 indice_pairs,
                                                 indice_pair_num,
                                                 spatial_shape,
                                                 out_spatial_shape,
                                                 is_subm=self.subm,
                                                 algo=algo,
                                                 ksize=self.kernel_size,
                                                 stride=self.stride,
                                                 padding=self.padding,
                                                 dilation=self.dilation)
                        if self.indice_key is not None:
                            msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
                            assert self.indice_key not in indice_dict, msg
                            indice_dict[self.indice_key] = indice_data
                if input.benchmark:
                    torch.cuda.synchronize()
                    t = time.time()
                indice_pairs_calc = indice_pairs
                if indice_pairs.device != features.device:
                    indice_pairs_calc = indice_pairs.to(features.device)
                if self.subm:
                    out_features = Fsp.indice_subm_conv(
                        features, self.weight, indice_pairs_calc,
                        indice_pair_num, outids.shape[0], algo, input._timer)
                else:
                    if self.inverse:
                        out_features = Fsp.indice_inverse_conv(
                            features, self.weight, indice_pairs_calc,
                            indice_pair_num, outids.shape[0], algo)
                    else:
                        out_features = Fsp.indice_conv(features, self.weight,
                                                       indice_pairs_calc,
                                                       indice_pair_num,
                                                       outids.shape[0], algo,
                                                       input._timer)

            else:
                datas = input.find_indice_pair(self.indice_key)
                if datas is not None:
                    assert isinstance(datas, ImplicitGemmIndiceData)
                if self.inverse:
                    assert datas is not None and self.indice_key is not None
                    assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
                    outids = datas.indices
                    # inverse conv runs the paired conv backwards: swap fwd/bwd
                    pair_fwd = datas.pair_bwd
                    pair_bwd = datas.pair_fwd
                    pair_mask_fwd_splits = datas.pair_mask_bwd_splits
                    pair_mask_bwd_splits = datas.pair_mask_fwd_splits
                    mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits
                    mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits
                    masks = datas.masks
                    out_spatial_shape = datas.spatial_shape
                    assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
                else:
                    if self.indice_key is not None and datas is not None:
                        outids = datas.out_indices
                        pair_fwd = datas.pair_fwd
                        pair_bwd = datas.pair_bwd
                        pair_mask_fwd_splits = datas.pair_mask_fwd_splits
                        pair_mask_bwd_splits = datas.pair_mask_bwd_splits
                        mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
                        mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
                        masks = datas.masks
                        assert self.subm, "only support reuse subm indices"
                        self._check_subm_reuse_valid(input, spatial_shape,
                                                     datas)
                    else:
                        # the timer may be None (see the profile_ctx check
                        # above); fall back to a no-op context in that case.
                        if input._timer is not None:
                            gen_pairs_ctx = input._timer.namespace("gen_pairs")
                        else:
                            gen_pairs_ctx = nullcontext()
                        with gen_pairs_ctx:
                            # we need to gen bwd indices for regular conv
                            # because it may be inversed.
                            try:
                                res = ops.get_indice_pairs_implicit_gemm(
                                    indices,
                                    batch_size,
                                    spatial_shape,
                                    algo,
                                    ksize=self.kernel_size,
                                    stride=self.stride,
                                    padding=self.padding,
                                    dilation=self.dilation,
                                    out_padding=self.output_padding,
                                    subm=self.subm,
                                    transpose=self.transposed,
                                    is_train=(not self.subm) or self.training,
                                    alloc=input.thrust_allocator,
                                    timer=input._timer)
                            except Exception:
                                msg = "[Exception|implicit_gemm_pair]"
                                msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
                                msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
                                msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
                                msg += f"transpose={self.transposed}"
                                print(msg, file=sys.stderr)
                                spconv_save_debug_data(indices)
                                raise

                        outids = res[0]
                        num_inds_per_loc = res[1]
                        pair_fwd = res[2]
                        pair_bwd = res[3]
                        pair_mask_fwd_splits = res[4]
                        pair_mask_bwd_splits = res[5]
                        mask_argsort_fwd_splits = res[6]
                        mask_argsort_bwd_splits = res[7]
                        masks = res[8]
                        if self.indice_key is not None:
                            indice_data = ImplicitGemmIndiceData(
                                outids,
                                indices,
                                pair_fwd,
                                pair_bwd,
                                pair_mask_fwd_splits=pair_mask_fwd_splits,
                                pair_mask_bwd_splits=pair_mask_bwd_splits,
                                mask_argsort_fwd_splits=mask_argsort_fwd_splits,
                                mask_argsort_bwd_splits=mask_argsort_bwd_splits,
                                masks=masks,
                                is_subm=self.subm,
                                spatial_shape=spatial_shape,
                                out_spatial_shape=out_spatial_shape,
                                algo=algo,
                                ksize=self.kernel_size,
                                stride=self.stride,
                                padding=self.padding,
                                dilation=self.dilation)
                            msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
                            assert self.indice_key not in indice_dict, msg
                            indice_dict[self.indice_key] = indice_data
                if input.benchmark:
                    torch.cuda.synchronize()
                    t = time.time()
                num_activate_out = outids.shape[0]
                out_features = Fsp.implicit_gemm(
                    features, self.weight, pair_fwd, pair_bwd,
                    pair_mask_fwd_splits, pair_mask_bwd_splits,
                    mask_argsort_fwd_splits, mask_argsort_bwd_splits,
                    num_activate_out, masks, self.training, self.subm,
                    input._timer, self.fp32_accum)
        if self.bias is not None:
            out_features += self.bias
        if input.benchmark:
            torch.cuda.synchronize()
            interval = time.time() - t
            out_tensor.benchmark_record[self.name]["time"].append(interval)
            out_tensor.benchmark_record[self.name]["num_points"].append(
                features.shape[0])
            out_tensor.benchmark_record[self.name]["num_out_points"].append(
                out_features.shape[0])
        out_tensor = out_tensor.replace_feature(out_features)
        out_tensor.indices = outids
        out_tensor.indice_dict = indice_dict
        out_tensor.spatial_shape = out_spatial_shape
        return out_tensor

    def _check_subm_reuse_valid(self, inp: SparseConvTensor,
                                spatial_shape: List[int],
                                datas: Union[ImplicitGemmIndiceData,
                                             IndiceData]):
        """Raise ValueError if cached subm indices can't be reused by this layer."""
        assert datas.is_subm, "only support reuse subm indices"
        if self.kernel_size != datas.ksize:
            raise ValueError(
                f"subm with same indice_key must have same kernel"
                f" size, expect {datas.ksize}, this layer {self.kernel_size}")
        if self.dilation != datas.dilation:
            raise ValueError(
                f"subm with same indice_key must have same dilation"
                f", expect {datas.dilation}, this layer {self.dilation}")
        if inp.spatial_shape != datas.spatial_shape:
            raise ValueError(
                f"subm with same indice_key must have same spatial structure"
                f", expect {datas.spatial_shape}, input {spatial_shape}")
        if inp.indices.shape[0] != datas.indices.shape[0]:
            raise ValueError(
                f"subm with same indice_key must have same num of indices"
                f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}"
            )


yan.yan's avatar
yan.yan committed
534
535
536
537
538
539
540
541
542
543
544
class SparseConv1d(SparseConvolution):
    """1-D sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: the leading positional
    argument is fixed to ``1`` (the dimensionality named by the class) and
    every other argument is forwarded unchanged.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(1, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)

traveller59's avatar
traveller59 committed
562
563
564
565
566
567
568
569
570
571
572

class SparseConv2d(SparseConvolution):
    """2-D sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: the leading positional
    argument is fixed to ``2`` (the dimensionality named by the class) and
    every other argument is forwarded unchanged.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(2, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)
traveller59's avatar
traveller59 committed
590
591
592
593
594
595
596
597
598
599
600
601


class SparseConv3d(SparseConvolution):
    """3-D sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: the leading positional
    argument is fixed to ``3`` (the dimensionality named by the class) and
    every other argument is forwarded unchanged.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(3, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)
619

traveller59's avatar
traveller59 committed
620

traveller59's avatar
traveller59 committed
621
622
623
624
625
626
627
628
629
630
class SparseConv4d(SparseConvolution):
    """4-D sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: the leading positional
    argument is fixed to ``4`` (the dimensionality named by the class) and
    every other argument is forwarded unchanged.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(4, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)
traveller59's avatar
traveller59 committed
648
649


yan.yan's avatar
bug fix  
yan.yan committed
650
651
652
653
654
655
656
657
658
659
660
class SparseConvTranspose1d(SparseConvolution):
    """1-D transposed sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: fixes the leading
    positional argument to ``1``, passes ``transposed=True`` and forwards
    every other argument unchanged.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(1, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias,
                         transposed=True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)


traveller59's avatar
traveller59 committed
680
681
682
683
684
685
686
687
688
689
class SparseConvTranspose2d(SparseConvolution):
    """2-D transposed sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: fixes the leading
    positional argument to ``2``, passes ``transposed=True`` and forwards
    every other argument unchanged.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(2, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias,
                         transposed=True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)
traveller59's avatar
traveller59 committed
708
709
710
711
712
713
714
715
716
717
718
719


class SparseConvTranspose3d(SparseConvolution):
    """3-D transposed sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: fixes the leading
    positional argument to ``3``, passes ``transposed=True`` and forwards
    every other argument unchanged.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(3, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias,
                         transposed=True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)
traveller59's avatar
traveller59 committed
738

yan.yan's avatar
v2.1  
yan.yan committed
739

yan.yan's avatar
bug fix  
yan.yan committed
740
741
742
743
744
745
746
747
748
749
750
class SparseConvTranspose4d(SparseConvolution):
    """4-D transposed sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: fixes the leading
    positional argument to ``4``, passes ``transposed=True`` and forwards
    every other argument unchanged.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(4, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias,
                         transposed=True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)


yan.yan's avatar
yan.yan committed
770
771
772
773
774
775
776
class SparseInverseConv1d(SparseConvolution):
    """1-D inverse sparse convolution layer.

    Convenience subclass of ``SparseConvolution`` that fixes the leading
    positional argument to ``1`` and passes ``inverse=True``. ``indice_key``
    is required (it has no default). Stride, padding, dilation and groups
    are not exposed, so whatever ``SparseConvolution`` does when they are
    omitted applies.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 indice_key,
                 bias=True,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(1, in_channels, out_channels, kernel_size,
                         bias=bias,
                         inverse=True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)

traveller59's avatar
traveller59 committed
791

traveller59's avatar
traveller59 committed
792
793
794
795
class SparseInverseConv2d(SparseConvolution):
    """2-D inverse sparse convolution layer.

    Convenience subclass of ``SparseConvolution`` that fixes the leading
    positional argument to ``2`` and passes ``inverse=True``. ``indice_key``
    is required (it has no default). Stride, padding, dilation and groups
    are not exposed, so whatever ``SparseConvolution`` does when they are
    omitted applies.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 indice_key,
                 bias=True,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(2, in_channels, out_channels, kernel_size,
                         bias=bias,
                         inverse=True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)
traveller59's avatar
traveller59 committed
812
813
814
815
816
817


class SparseInverseConv3d(SparseConvolution):
    """3-D inverse sparse convolution layer.

    Convenience subclass of ``SparseConvolution`` that fixes the leading
    positional argument to ``3`` and passes ``inverse=True``. ``indice_key``
    is required (it has no default). Stride, padding, dilation and groups
    are not exposed, so whatever ``SparseConvolution`` does when they are
    omitted applies.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 indice_key,
                 bias=True,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(3, in_channels, out_channels, kernel_size,
                         bias=bias,
                         inverse=True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)
traveller59's avatar
traveller59 committed
834

yan.yan's avatar
v2.1  
yan.yan committed
835

yan.yan's avatar
yan.yan committed
836
837
838
839
840
841
842
class SparseInverseConv4d(SparseConvolution):
    """4-D inverse sparse convolution layer.

    Convenience subclass of ``SparseConvolution`` that fixes the leading
    positional argument to ``4`` and passes ``inverse=True``. ``indice_key``
    is required (it has no default). Stride, padding, dilation and groups
    are not exposed, so whatever ``SparseConvolution`` does when they are
    omitted applies.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 indice_key,
                 bias=True,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(4, in_channels, out_channels, kernel_size,
                         bias=bias,
                         inverse=True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)

yan.yan's avatar
v2.1  
yan.yan committed
857

yan.yan's avatar
yan.yan committed
858
859
860
861
862
863
864
865
866
867
868
class SubMConv1d(SparseConvolution):
    """1-D submanifold sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: fixes the leading
    positional argument to ``1`` and passes ``True`` as the positional
    argument right after ``bias`` (presumably the submanifold flag of
    ``SparseConvolution.__init__`` -- confirm against its signature).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(1, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias, True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)

traveller59's avatar
traveller59 committed
887
888
889
890
891
892
893
894
895
896
897

class SubMConv2d(SparseConvolution):
    """2-D submanifold sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: fixes the leading
    positional argument to ``2`` and passes ``True`` as the positional
    argument right after ``bias`` (presumably the submanifold flag of
    ``SparseConvolution.__init__`` -- confirm against its signature).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(2, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias, True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)
traveller59's avatar
traveller59 committed
916
917
918
919
920
921
922
923
924
925
926
927


class SubMConv3d(SparseConvolution):
    """3-D submanifold sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: fixes the leading
    positional argument to ``3`` and passes ``True`` as the positional
    argument right after ``bias`` (presumably the submanifold flag of
    ``SparseConvolution.__init__`` -- confirm against its signature).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(3, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias, True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)
946

traveller59's avatar
traveller59 committed
947
948
949
950
951
952
953
954
955
956
957

class SubMConv4d(SparseConvolution):
    """4-D submanifold sparse convolution layer.

    Convenience subclass of ``SparseConvolution``: fixes the leading
    positional argument to ``4`` and passes ``True`` as the positional
    argument right after ``bias`` (presumably the submanifold flag of
    ``SparseConvolution.__init__`` -- confirm against its signature).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(4, in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias, True,
                         indice_key=indice_key,
                         algo=algo,
                         fp32_accum=fp32_accum,
                         name=name)