# dev_subm.py
# Recovered from a GitLab blame view; last change: "v2.1", committed by yan.yan.
import sys
from pathlib import Path
from typing import Dict, List, Tuple
import pickle
import sys
import time
from pathlib import Path
from cumm.gemm.algospec.core import GemmAlgo

import numpy as np
import pccm
import torch
import torch.nn.functional as F

from cumm import dtypes
from cumm import tensorview as tv
from cumm.constants import PACKAGE_ROOT
from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType
from cumm.conv.main import ConvMainUnitTest, gen_gemm_kernels
from cumm.conv.params import ConvProblem
from cumm.gemm import kernel
import os
from spconv.core_cc.csrc.sparse.all import SpconvOps
from cumm.gemm.codeops import div_up
from spconv.constants import PACKAGE_ROOT
from spconv.core import ConvAlgo

from spconv.pytorch import ops
from spconv.algo import CONV, BestConvAlgoByProfile
from spconv.pytorch.cppcore import torch_tensor_to_tv


def reduce_mask_count(mask: np.ndarray, width: int):
    """Count mask bits that survive OR-ing groups of ``width`` entries.

    ``mask`` is zero-padded at the end to a multiple of ``width``, reshaped
    into rows of ``width`` consecutive entries, and each row is collapsed with
    bitwise OR. The set bits of every collapsed row are counted with
    ``SpconvOps.count_bits`` and the total is multiplied by ``width``, i.e.
    each surviving bit is counted once per entry of its row.

    Args:
        mask: 1-D bit-mask array (``dev_subm_inds_v2`` passes the uint32
            forward split masks).
        width: number of consecutive entries that share one reduced mask.

    Returns:
        ``width`` times the number of set bits over all reduced rows.
    """
    num_entries = mask.shape[0]
    # Round up to the next multiple of ``width`` (== div_up(n, width) * width).
    padded_length = (num_entries + width - 1) // width * width
    if num_entries < padded_length:
        mask_pad = np.zeros((padded_length, ), dtype=mask.dtype)
        mask_pad[:num_entries] = mask
        mask = mask_pad
    rows = mask.reshape(-1, width)
    row_masks = np.bitwise_or.reduce(rows, axis=1)
    bit_counts = SpconvOps.count_bits(tv.from_numpy(row_masks)).numpy()
    return bit_counts.sum() * width


def reduce_mask_count_x(mask: np.ndarray, width: int):
    """OR-reduce ``mask`` over groups of ``width`` consecutive entries.

    ``mask`` is zero-padded at the end to a multiple of ``width`` (the input
    array itself is not modified), reshaped into rows of ``width`` entries,
    and each row is collapsed with bitwise OR.

    Args:
        mask: 1-D integer bit-mask array.
        width: number of consecutive entries OR-ed together.

    Returns:
        1-D array with ``ceil(len(mask) / width)`` reduced masks.
    """
    num_entries = mask.shape[0]
    # Round up to the next multiple of ``width`` (== div_up(n, width) * width).
    padded_length = (num_entries + width - 1) // width * width
    if num_entries < padded_length:
        mask_pad = np.zeros((padded_length, ), dtype=mask.dtype)
        mask_pad[:num_entries] = mask
        mask = mask_pad
    rows = mask.reshape(-1, width)
    return np.bitwise_or.reduce(rows, axis=1)


def _num_pairs_at(filter_offset: int, kv: int, subm: bool,
                  indice_num_per_loc, num_voxels: int):
    """Return how many reference indice pairs to use for one kernel offset.

    Args:
        filter_offset: flattened kernel offset in ``range(kv)``.
        kv: kernel volume (product of the kernel size).
        subm: whether the pairs come from a submanifold convolution.
        indice_num_per_loc: per-offset pair counts (numpy array or sequence)
            from the native ``ops.get_indice_pairs``.
        num_voxels: number of input voxels.

    For submanifold convolution the center offset (``kv // 2``) uses
    ``num_voxels`` pairs, and offsets past the center reuse the count of the
    mirrored offset ``kv - 1 - filter_offset`` (presumably because the subm
    pair table is symmetric -- TODO confirm). Otherwise the count is read
    directly from ``indice_num_per_loc``.
    """
    if subm and filter_offset > kv // 2:
        return indice_num_per_loc[kv - 1 - filter_offset]
    if subm and filter_offset == kv // 2:
        return num_voxels
    return indice_num_per_loc[filter_offset]


def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
    """Check implicit-GEMM indice pairs and conv kernels against numpy.

    Builds a small random sparse problem and computes indice pairs twice:
    with ``ConvAlgo.Native`` (used as the reference) and with
    ``ConvAlgo.MaskSplitImplicitGemm``. It then prints pair/mask statistics.
    When ``run_conv`` is set, it also runs every SIMT kernel in
    ``CONV.desps`` and prints the L2 error of its result against a numpy
    reference (forward, backward-input or backward-weight, depending on the
    kernel's op type).

    Args:
        subm: use a 3x3x3 submanifold convolution with padding 1 instead of a
            regular 2x2x2 convolution with padding 0.
        run_conv: also run and verify the convolution kernels.

    Requires a CUDA device and the fixture ``test/data/test_spconv.pkl``.
    """
    # Set to e.g. 16384 to truncate the fixture; None keeps every entry.
    limit_input_n = None
    np.random.seed(484)

    # Trusted repository test data -- never unpickle untrusted input.
    with (PACKAGE_ROOT.parent / "test/data/test_spconv.pkl").open("rb") as f:
        voxels_np, indices_np, spatial_shape = pickle.load(f)
    voxels_np = voxels_np[:limit_input_n]
    indices_np = indices_np[:limit_input_n]

    # NOTE: the fixture loaded above is currently overridden by a small
    # synthetic problem.
    from spconv.test_utils import generate_sparse_data
    spatial_shape = [19, 18, 17]
    sparse_dict = generate_sparse_data(spatial_shape, [1024], 128)
    voxels_np = np.ascontiguousarray(sparse_dict["features"]).astype(
        np.float32)
    # Move the last index column to the front (presumably the batch index).
    indices_np = np.ascontiguousarray(
        sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
    indices_th = torch.from_numpy(indices_np).cuda()
    num_voxels = voxels_np.shape[0]
    print(spatial_shape, indices_np.shape)

    ndim = 3
    if subm:
        ksize = [3] * ndim
        padding = [1] * ndim
    else:
        ksize = [2] * ndim
        padding = [0] * ndim
    stride = [1] * ndim
    dilation = [1] * ndim
    out_padding = [0] * ndim
    kv = int(np.prod(ksize))

    # Reference pairs from the native algorithm.
    _, pair_ref, indice_num_per_loc = ops.get_indice_pairs(
        indices_th, 1, spatial_shape, ConvAlgo.Native, ksize, stride, padding,
        dilation, out_padding, subm)
    indice_num_per_loc_np = indice_num_per_loc.cpu().numpy()
    indice_pairs_np = pair_ref.cpu().numpy()

    algo = ConvAlgo.MaskSplitImplicitGemm
    num_split = 1 if algo == ConvAlgo.MaskImplicitGemm else 2
    # Called repeatedly (presumably to warm up / profile pair generation);
    # only the last result is used.
    for _ in range(5):
        res = ops.get_indice_pairs_implicit_gemm(indices_th, 1, spatial_shape,
                                                 algo, ksize, stride, padding,
                                                 dilation, out_padding, subm)
    out_inds = res[0]
    pair_fwd = res[2]
    pair_bwd = res[3]
    pair_mask_fwd_splits = res[4]
    pair_mask_bwd_splits = res[5]
    mask_argsort_fwd_splits = res[6]
    mask_argsort_bwd_splits = res[7]
    masks = res[8]

    # Per-offset number of valid (!= -1) forward pairs. Counting "> 0" after
    # replacing -1 with 0 would also drop valid pairs that point at index 0.
    pair_fwd_np = pair_fwd.cpu().numpy().reshape(kv, -1)
    loc_num_np = (pair_fwd_np != -1).sum(1)
    print(loc_num_np)
    print(indice_num_per_loc_np)

    pair_mask_fwd_splits_tv = [
        torch_tensor_to_tv(t, dtype=tv.uint32) for t in pair_mask_fwd_splits
    ]
    # Exact number of set mask bits vs. the count obtained when masks are
    # OR-reduced over groups of `reduce_length` entries.
    valid_location_count = sum(
        SpconvOps.count_bits(t).cpu().numpy().sum()
        for t in pair_mask_fwd_splits_tv)
    reduce_length = 32
    split_mask_valid_count = sum(
        reduce_mask_count(t.cpu().numpy(), reduce_length)
        for t in pair_mask_fwd_splits_tv)
    print("SUBM" if subm else "REGULAR", valid_location_count,
          split_mask_valid_count, pair_fwd.numel())

    if not run_conv:
        return

    C = 64
    K = 64
    # One row of div_up(n, 32) 32-bit mask words per split.
    mask_output_fwd = torch.zeros([num_split, div_up(out_inds.shape[0], 32)],
                                  dtype=torch.int32,
                                  device=indices_th.device)
    mask_output_bwd = torch.zeros([num_split, div_up(indices_np.shape[0], 32)],
                                  dtype=torch.int32,
                                  device=indices_th.device)

    for desp in CONV.desps:
        if desp.algo != GemmAlgo.Simt.value:
            continue
        is_fwd = desp.op_type == ConvOpType.kForward.value
        is_bwd_input = desp.op_type == ConvOpType.kBackwardInput.value
        is_bwd_weight = desp.op_type == ConvOpType.kBackwardWeight.value

        # Random operands; draw order (inp, weight, output) is kept stable so
        # results are reproducible under the fixed seed.
        out_npdtype = dtypes.get_npdtype_from_tvdtype(desp.dtype_output)
        if desp.dtype_a == dtypes.int8.tv_dtype:
            # NOTE: randint's upper bound is exclusive -> values in {-1, 0}.
            inp = np.random.randint(-1, 1, size=[num_voxels,
                                                 C]).astype(np.int8)
            weight = np.random.randint(-1, 1, size=[K, *ksize,
                                                    C]).astype(np.int8)
            output = np.random.randint(
                -1, 1, size=[out_inds.shape[0], K]).astype(out_npdtype)
        else:
            inp = np.random.uniform(-1, 1, size=[num_voxels, C]).astype(
                dtypes.get_npdtype_from_tvdtype(desp.dtype_input))
            weight = np.random.uniform(-1, 1, size=[K, *ksize, C]).astype(
                dtypes.get_npdtype_from_tvdtype(desp.dtype_weight))
            output = np.random.uniform(
                -1, 1, size=[out_inds.shape[0], K]).astype(out_npdtype)
        # [K, *ksize, C] -> [kv, K, C] (the transpose assumes ndim == 3).
        weight_ref = np.ascontiguousarray(
            weight.transpose(1, 2, 3, 0, 4)).reshape(-1, K, C)

        # The operand this op computes starts as zeros; the others are
        # uploaded from the numpy data.
        if is_bwd_input:
            inp_tv = tv.zeros(inp.shape, desp.dtype_input, 0)
        else:
            inp_tv = tv.from_numpy(inp).cuda()
        if is_bwd_weight:
            weight_tv = tv.zeros(weight.shape, desp.dtype_weight, 0)
        else:
            weight_tv = tv.from_numpy(weight).cuda()
        if is_fwd:
            output_tv = tv.zeros(output.shape, desp.dtype_output, 0)
        else:
            output_tv = tv.from_numpy(output).cuda()

        torch.cuda.synchronize()
        t = time.time()
        # TODO support splitk parallel
        spk = 32 if is_bwd_weight else 1

        if subm:
            # Submanifold: all ops use the forward masks/argsorts; backward
            # input uses pair_bwd with reverse_mask=True.
            indice_pairs = pair_bwd if is_bwd_input else pair_fwd
            mask_ops = pair_mask_fwd_splits
            mask_argsorts = mask_argsort_fwd_splits
            mask_output = mask_output_fwd
            reverse_mask = is_bwd_input
        elif is_bwd_input:
            indice_pairs = pair_bwd  # out -> inp
            mask_ops = pair_mask_bwd_splits
            mask_argsorts = mask_argsort_bwd_splits
            mask_output = mask_output_bwd
            reverse_mask = False
            print([bin(x.item()) for x in masks])
        else:
            indice_pairs = pair_fwd  # inp -> out
            mask_ops = pair_mask_fwd_splits
            mask_argsorts = mask_argsort_fwd_splits
            mask_output = mask_output_fwd
            reverse_mask = False

        for j in range(num_split):
            # beta=1 for later splits (presumably accumulating into the
            # result written by the first split).
            beta = 1 if j > 0 else 0
            mask_filter = masks[j].item()
            # Backward-weight reads its mask from the output-mask buffer.
            mask_op = mask_output[j] if is_bwd_weight else mask_ops[j]
            CONV.run_with_tuned_result(
                BestConvAlgoByProfile(desp, spk),
                desp.op_type,
                inp_tv,
                weight_tv,
                output_tv,
                torch_tensor_to_tv(mask_op, dtype=tv.uint32),
                torch_tensor_to_tv(mask_argsorts[j]),
                torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
                torch_tensor_to_tv(indice_pairs),
                reverse_mask,
                mask_filter=mask_filter,
                mask_width=32,
                beta=beta,
                verbose=True,
            )

        torch.cuda.synchronize()
        duration = time.time() - t
        print("TIME", duration)

        if is_fwd:
            output_ref = np.zeros_like(output, dtype=np.float32)
            for filter_offset in range(kv):
                nhot = _num_pairs_at(filter_offset, kv, subm,
                                     indice_num_per_loc_np, num_voxels)
                a_inds = indice_pairs_np[0][filter_offset][:nhot]
                c_inds = indice_pairs_np[1][filter_offset][:nhot]
                # [N, C] @ [C, K]
                cc = inp[a_inds].astype(
                    np.float32) @ weight_ref[filter_offset].T.astype(
                        np.float32)
                output_ref[c_inds] += cc
            output_cpu = output_tv.cpu().numpy().astype(np.float32)
            print("ERROR",
                  np.linalg.norm(output_ref.reshape(-1) - output_cpu.reshape(-1)))
        elif is_bwd_input:
            dinput_ref = np.zeros_like(inp, dtype=np.float32)
            for filter_offset in range(kv):
                nhot = _num_pairs_at(filter_offset, kv, subm,
                                     indice_num_per_loc_np, num_voxels)
                a_inds = indice_pairs_np[1][filter_offset][:nhot]
                c_inds = indice_pairs_np[0][filter_offset][:nhot]
                # [N, K] @ [K, C]
                cc = output[a_inds].astype(
                    np.float32) @ weight_ref[filter_offset].astype(np.float32)
                dinput_ref[c_inds] += cc
            din_cpu = inp_tv.cpu().numpy().astype(np.float32)
            print("ERROR",
                  np.linalg.norm(din_cpu.reshape(-1) - dinput_ref.reshape(-1)))
        else:
            dw_ref = np.zeros_like(weight_ref, dtype=np.float32)  # [kv, K, C]
            for filter_offset in range(kv):
                nhot = _num_pairs_at(filter_offset, kv, subm,
                                     indice_num_per_loc_np, num_voxels)
                o_inds = indice_pairs_np[1][filter_offset][:nhot]
                i_inds = indice_pairs_np[0][filter_offset][:nhot]
                out_gather = output[o_inds]  # [N, K]
                inp_gather = inp[i_inds]  # [N, C]
                # [K, N] @ [N, C]
                dw_ref[filter_offset] = out_gather.astype(
                    np.float32).T @ inp_gather.astype(np.float32)
            dw_ref_kcrs = dw_ref.transpose(1, 0, 2)  # [K, kv, C]
            dw_cpu = weight_tv.cpu().numpy().reshape(K, kv, C)
            print("ERROR",
                  np.linalg.norm(dw_cpu.reshape(-1) - dw_ref_kcrs.reshape(-1)))


if __name__ == "__main__":
    # Script entry point: run the regular (non-submanifold) check with the
    # convolution kernels enabled -- the function's default arguments.
    dev_subm_inds_v2(subm=False, run_conv=True)