# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
16
from cumm import tensorview as tv
import torch
yan.yan's avatar
sync  
yan.yan committed
17
from typing import Dict, Optional, List, Union
yan.yan's avatar
yan.yan committed
18
from spconv.constants import AllocKeys
yan.yan's avatar
yan.yan committed
19
from spconv.cppconstants import COMPILED_CUDA_ARCHS
yan.yan's avatar
yan.yan committed
20
import sys
yan.yan's avatar
sync  
yan.yan committed
21
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
yan.yan's avatar
yan.yan committed
22
23
24
from spconv.core_cc.csrc.sparse.convops import ExternalSpconvMatmul

import numpy as np
25

yan.yan's avatar
yan.yan committed
26
27
28
29
30
31
32
33
34
35
36
# torch dtype -> tv dtype code, for the dtypes that exist on both sides.
_TORCH_DTYPE_TO_TV = {
    torch.float32: tv.float32,
    torch.float64: tv.float64,
    torch.float16: tv.float16,
    torch.int32: tv.int32,
    torch.int64: tv.int64,
    torch.int8: tv.int8,
    torch.int16: tv.int16,
    torch.uint8: tv.uint8,
}

# Unsigned tv dtype -> signed tv dtype of the same width that stands in for
# it on the torch side (presumably because torch lacks these unsigned
# dtypes -- the table name suggests so).
_TORCH_UINT_WORKAROUNDS = {
    tv.uint32: tv.int32,
    tv.uint16: tv.int16,
    tv.uint64: tv.int64,
}

# tv dtype code -> torch dtype: the inverse of _TORCH_DTYPE_TO_TV ...
_TV_DTYPE_TO_TORCH = {
    tv_dtype: th_dtype
    for th_dtype, tv_dtype in _TORCH_DTYPE_TO_TV.items()
}
# ... plus the unsigned workarounds, derived from _TORCH_UINT_WORKAROUNDS so
# the two tables cannot drift apart (uint32 -> torch.int32, and so on).
_TV_DTYPE_TO_TORCH.update({
    unsigned: _TV_DTYPE_TO_TORCH[signed]
    for unsigned, signed in _TORCH_UINT_WORKAROUNDS.items()
})

# Every integer tv dtype, signed and unsigned.
_ALL_INTS = {
    tv.int32, tv.int16, tv.int8, tv.int64, tv.uint64, tv.uint8, tv.uint32,
    tv.uint16
}

56
57
58

def torch_tensor_to_tv(ten: torch.Tensor,
                       dtype: Optional[int] = None,
                       shape: Optional[List[int]] = None,
                       stride: Optional[List[int]] = None):
    """Wrap the memory of a torch tensor in a tv tensor.

    The tv tensor is built from ``ten.data_ptr()`` via ``tv.from_blob`` /
    ``tv.from_blob_strided``, so the caller must keep ``ten`` alive while
    the result is in use.

    Args:
        ten: source tensor; its device type must be ``"cpu"`` or ``"cuda"``.
        dtype: tv dtype code to interpret the memory as. Defaults to the
            entry for ``ten.dtype`` in ``_TORCH_DTYPE_TO_TV``.
        shape: custom shape. Defaults to ``ten.shape``.
        stride: custom strides. Defaults to ``ten.stride()`` when no custom
            shape is given. NOTE: for a contiguous tensor with a custom
            shape the strides are computed by ``tv.from_blob`` and this
            argument is ignored.

    Returns:
        The object returned by ``tv.from_blob`` or ``tv.from_blob_strided``.

    Raises:
        NotImplementedError: if the tensor is on neither a CPU nor a CUDA
            device.
        KeyError: if ``dtype`` is omitted and ``ten.dtype`` has no entry in
            ``_TORCH_DTYPE_TO_TV``.
        ValueError: if a custom ``shape`` is given for a non-contiguous
            tensor without an explicit ``stride``.
    """
    # assert ten.is_contiguous(), "must be contiguous tensor"
    ptr = ten.data_ptr()
    device = ten.device
    if device.type == "cpu":
        tv_device = -1
    elif device.type == "cuda":
        tv_device = 0
    else:
        raise NotImplementedError(
            f"unsupported device type: {device.type}")
    if dtype is None:
        dtype = _TORCH_DTYPE_TO_TV[ten.dtype]
    if shape is None:
        shape = list(ten.shape)
    elif ten.is_contiguous():
        # custom shape, if tensor is contiguous, we use from_blob and calc strides
        return tv.from_blob(ptr, shape, dtype, tv_device)
    elif stride is None:
        # This must be checked BEFORE stride is defaulted: the tensor's own
        # strides describe its own shape, not the custom one.
        raise ValueError(
            "if you provide custom shape for non-contig tensor, stride must not None"
        )
    if stride is None:
        stride = list(ten.stride())
    return tv.from_blob_strided(ptr, shape, stride, dtype, tv_device)
yan.yan's avatar
yan.yan committed
84

yan.yan's avatar
yan.yan committed
85

86
87
88
def torch_tensors_to_tv(*tens: torch.Tensor):
    """Lazily convert each given torch tensor with ``torch_tensor_to_tv``.

    Returns a generator: a tensor is converted only when the generator is
    iterated (or unpacked), and the result can be consumed once.
    """

    def _converted():
        for tensor in tens:
            yield torch_tensor_to_tv(tensor)

    return _converted()

89

yan.yan's avatar
yan.yan committed
90
91
92
def get_current_stream():
    """Return the ``cuda_stream`` handle of torch's current CUDA stream."""
    current = torch.cuda.current_stream()
    return current.cuda_stream

yan.yan's avatar
yan.yan committed
93

yan.yan's avatar
yan.yan committed
94
95
96
def get_arch():
    """Return the compute capability of the current CUDA device.

    The value comes from ``torch.cuda.get_device_capability()``. When it is
    not listed in ``COMPILED_CUDA_ARCHS`` a warning is printed to stderr;
    the capability is returned either way.
    """
    capability = torch.cuda.get_device_capability()
    if capability in COMPILED_CUDA_ARCHS:
        return capability
    warning = (f"[WARNING]your gpu arch {capability} isn't compiled in prebuilt, "
               f"may cause invalid device function. "
               f"available: {COMPILED_CUDA_ARCHS}")
    print(warning, file=sys.stderr)
    return capability
yan.yan's avatar
sync  
yan.yan committed
103

yan.yan's avatar
yan.yan committed
104

yan.yan's avatar
sync  
yan.yan committed
105
class TorchAllocator(ExternalAllocator):
    """``ExternalAllocator`` implementation backed by torch tensors.

    Every allocation creates a torch tensor, wraps it with
    ``torch_tensor_to_tv`` and keeps the torch tensor alive in
    ``self.allocated``:

    * always under the byte pointer of the tv view, so ``free`` /
      ``free_noexcept`` can release it;
    * additionally under ``name`` when a non-empty name is given and the
      allocation is not flagged as temporary, so it can be fetched later
      with ``get_tensor_by_name``.
    """

    def __init__(self, gpudevice: torch.device) -> None:
        super().__init__()
        self.gpudevice = gpudevice
        self.cpudevice = torch.device("cpu")
        # Keys are byte pointers and, for named allocations, names.
        self.allocated: Dict[Union[str, int], torch.Tensor] = {}

    def _torch_device(self, device: int) -> torch.device:
        """Map a tv device index to a torch device (-1 selects the CPU)."""
        return self.cpudevice if device == -1 else self.gpudevice

    def _track(self, name: str, ten: torch.Tensor, dtype: int,
               is_temp_memory: bool) -> tv.Tensor:
        """Wrap ``ten`` as a tv tensor of ``dtype`` and record it in the cache."""
        # The requested tv dtype is passed through so unsigned dtypes keep
        # their tv dtype even though the torch tensor uses the signed one.
        ten_tv = torch_tensor_to_tv(ten, dtype)
        self.allocated[ten_tv.byte_pointer()] = ten
        if name and not is_temp_memory:
            self.allocated[name] = ten
        return ten_tv

    def _full(self, name: str, shape: List[int], value: Union[int, float],
              dtype: int, device: int, is_temp_memory: bool) -> tv.Tensor:
        """Shared implementation of ``full_int`` and ``full_float``."""
        if dtype in _TORCH_UINT_WORKAROUNDS and value < 0:
            raise NotImplementedError("you can't use full for unsigned dtypes")
        th_dtype = _TV_DTYPE_TO_TORCH[dtype]
        ten = torch.full(shape, value, dtype=th_dtype,
                         device=self._torch_device(device))
        return self._track(name, ten, dtype, is_temp_memory)

    def zeros(self, name: str, shape: List[int], dtype: int,
              device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
        """Allocate a zero-filled tensor and return its tv view.

        Args:
            name: cache key for later lookup; ignored when empty or when
                ``is_temp_memory`` is true.
            shape: tensor shape.
            dtype: tv dtype code; must be a key of ``_TV_DTYPE_TO_TORCH``.
            device: -1 for CPU, anything else for ``self.gpudevice``.
            stream: not used by this implementation.
            is_temp_memory: when true the tensor is tracked by pointer only.
        """
        # TODO free memory by name if its already free by pointer.
        # provide a name if you want to access it after c++ function exit.
        th_dtype = _TV_DTYPE_TO_TORCH[dtype]
        ten = torch.zeros(shape, dtype=th_dtype,
                          device=self._torch_device(device))
        return self._track(name, ten, dtype, is_temp_memory)

    def empty(self, name: str, shape: List[int], dtype: int,
              device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
        """Allocate an uninitialized tensor; arguments are as for ``zeros``."""
        th_dtype = _TV_DTYPE_TO_TORCH[dtype]
        ten = torch.empty(shape, dtype=th_dtype,
                          device=self._torch_device(device))
        return self._track(name, ten, dtype, is_temp_memory)

    def full_int(self, name: str, shape: List[int], value: int, dtype: int,
                 device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
        """Allocate a tensor filled with the integer ``value``.

        Other arguments are as for ``zeros``.

        Raises:
            NotImplementedError: if ``dtype`` is one of the unsigned dtypes
                in ``_TORCH_UINT_WORKAROUNDS`` and ``value`` is negative.
        """
        return self._full(name, shape, value, dtype, device, is_temp_memory)

    def full_float(self, name: str, shape: List[int], value: float, dtype: int,
                   device: int, stream: int = 0, is_temp_memory: bool = False) -> tv.Tensor:
        """Allocate a tensor filled with the float ``value``.

        Other arguments are as for ``zeros``.

        Raises:
            NotImplementedError: if ``dtype`` is one of the unsigned dtypes
                in ``_TORCH_UINT_WORKAROUNDS`` and ``value`` is negative.
        """
        return self._full(name, shape, value, dtype, device, is_temp_memory)

    def get_tensor_by_name(self, name: str):
        """Return a tv view of the tensor cached under ``name``.

        The view is built from the torch tensor's own dtype. Raises
        ``KeyError`` when nothing is cached under ``name``.
        """
        return torch_tensor_to_tv(self.allocated[name])

    def free(self, ten: tv.Tensor):
        """Drop the pointer-keyed cache entry for ``ten``.

        Only the entry keyed by the byte pointer is removed; an entry stored
        under a name is left in place.

        Raises:
            ValueError: if ``ten`` is a slice (its byte size differs from its
                storage byte size) or is not in the cache.
        """
        if ten.storage_bytesize() != ten.bytesize():
            raise ValueError("you can't free a sliced tensor.")
        ptr = ten.byte_pointer()
        if ptr not in self.allocated:
            raise ValueError("can't find your tensor in cache.")
        del self.allocated[ptr]

    def free_noexcept(self, ten: tv.Tensor):
        """Like ``free`` but silently ignores slices and unknown tensors."""
        # for c++ scope guard, free will be called in c++ destructor
        if ten.storage_bytesize() != ten.bytesize():
            return
        self.allocated.pop(ten.byte_pointer(), None)


yan.yan's avatar
yan.yan committed
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
class TorchSpconvMatmul(ExternalSpconvMatmul):
    """``ExternalSpconvMatmul`` callbacks implemented with torch matmuls.

    All tensor arguments are passed by name and resolved through
    ``alloc.allocated``. Filters come in one of two layouts selected by
    ``all_weight_is_krsc``: when true they are reshaped to
    ``(first_dim, -1, last_dim)`` and the kernel-offset axis is dim 1;
    otherwise they are reshaped to ``(-1, *last_two_dims)`` and the
    kernel-offset axis is dim 0.
    """

    def __init__(self, alloc: TorchAllocator) -> None:
        super().__init__()
        # Allocator whose ``allocated`` dict resolves tensor names.
        self.alloc = alloc

    def indice_conv_init_gemm(self, features_n: str, filters_n: str,
                              all_weight_is_krsc: bool, is_kc_not_ck: bool,
                              kv_center: int, out_channel: int, stream_int: int = 0):
        """Multiply the features with the ``kv_center`` filter slice.

        The result is stored in the allocator under
        ``AllocKeys.OutFeatures`` and returned as a tv view.

        Args:
            features_n: allocator key of the feature matrix.
            filters_n: allocator key of the filters.
            all_weight_is_krsc: selects the filter layout (see class doc).
            is_kc_not_ck: only consulted for the non-KRSC layout; when true
                the filter slice is transposed before the multiplication.
            kv_center: index of the filter slice along the kernel-offset axis.
            out_channel: leading filter dimension used for the KRSC reshape
                and for the output shape of the numpy fallback.
            stream_int: not used by this implementation.
        """
        features = self.alloc.allocated[features_n]
        filters = self.alloc.allocated[filters_n]
        if not all_weight_is_krsc:
            filters = filters.reshape(-1, *filters.shape[-2:])
            if not is_kc_not_ck:
                out_features = torch.mm(features, filters[kv_center])
            else:
                out_features = torch.mm(features, filters[kv_center].T)
        else:
            filters = filters.reshape(out_channel, -1, filters.shape[-1])
            if features.is_cuda or (features.dtype != torch.float16):
                out_features = torch.mm(features, filters[:, kv_center].T)
            else:
                # pytorch 1.12 doesn't support half-precision mm on CPU, so
                # fall back to numpy. CPU fp16 mm is needed for tests only.
                # The numpy views presumably share memory with the torch
                # tensors, so ``out=`` fills ``out_features`` in place.
                out_features = torch.empty((features.shape[0], out_channel),
                                           dtype=features.dtype,
                                           device=features.device)
                features_np = torch_tensor_to_tv(features).numpy_view()
                filters_np = torch_tensor_to_tv(filters).numpy_view()
                out_features_np = torch_tensor_to_tv(out_features).numpy_view()
                np.matmul(features_np,
                          filters_np[:, kv_center].T,
                          out=out_features_np)
        # Registered by key so the result can be looked up in the allocator.
        self.alloc.allocated[AllocKeys.OutFeatures] = out_features

        return torch_tensor_to_tv(out_features)

    def indice_conv_cpu_gemm(self, inp_buffer_n: str, out_buffer_n: str, filters_n: str,
                             all_weight_is_krsc: bool,
                             is_kc_not_ck: bool, nhot: int, index: int):
        """Write ``inp_buffer[:nhot] @ W`` into ``out_buffer[:nhot]``.

        ``W`` is filter slice ``index`` along the kernel-offset axis,
        transposed when ``is_kc_not_ck`` is true. Nothing is returned; the
        output buffer is updated in place.

        Args:
            inp_buffer_n: allocator key of the input buffer.
            out_buffer_n: allocator key of the output buffer.
            filters_n: allocator key of the filters.
            all_weight_is_krsc: selects the filter layout (see class doc).
            is_kc_not_ck: whether the filter slice must be transposed.
            nhot: number of leading rows of both buffers to process.
            index: kernel-offset index of the filter slice.
        """
        # Axis that holds the kernel offsets in the reshaped filters.
        kv_dim = 1 if all_weight_is_krsc else 0
        inp_buffer = self.alloc.allocated[inp_buffer_n]
        filters = self.alloc.allocated[filters_n]
        if not all_weight_is_krsc:
            filters = filters.reshape(-1, *filters.shape[-2:])
        else:
            filters = filters.reshape(filters.shape[0], -1, filters.shape[-1])
        out_buffer = self.alloc.allocated[out_buffer_n]
        filters_i = filters.select(kv_dim, index)
        filters_cur = filters_i if not is_kc_not_ck else filters_i.T
        if inp_buffer.dtype == torch.float16:
            # Half precision goes through numpy; the same slice/transpose as
            # above is repeated on the numpy views.
            inp_buffer_np = torch_tensor_to_tv(inp_buffer).numpy_view()
            filters_np = torch_tensor_to_tv(filters).numpy_view()
            filters_i_np = filters_np[
                index] if not all_weight_is_krsc else filters_np[:, index]
            filters_cur_np = filters_i_np if not is_kc_not_ck else filters_i_np.T
            out_buffer_np = torch_tensor_to_tv(out_buffer).numpy_view()
            np.matmul(inp_buffer_np[:nhot],
                      filters_cur_np,
                      out=out_buffer_np[:nhot])
        else:
            torch.mm(inp_buffer[:nhot], filters_cur, out=out_buffer[:nhot])

    def indice_conv_bwd_init_gemm(self, features_n: str, filters_n: str,
                                  out_bp_n: str, dfilters_n: str,
                                  all_weight_is_krsc: bool, is_kc_not_ck: bool,
                                  kv_center: int, stream_int: int = 0):
        """Backward step for the ``kv_center`` filter slice.

        Writes a product of ``features`` and ``out_bp`` into the
        ``kv_center`` slice of ``dfilters`` in place, computes ``din`` as
        ``out_bp`` times the filter slice, stores ``din`` in the allocator
        under ``AllocKeys.DIn`` and returns it as a tv view.

        Args:
            features_n: allocator key of the features.
            filters_n: allocator key of the filters.
            out_bp_n: allocator key of ``out_bp`` (presumably the gradient
                w.r.t. the output -- confirm against the caller).
            dfilters_n: allocator key of the tensor receiving the filter
                slice product.
            all_weight_is_krsc: selects the filter layout (see class doc).
            is_kc_not_ck: only consulted for the non-KRSC layout; selects
                which operand is transposed.
            kv_center: index of the filter slice along the kernel-offset axis.
            stream_int: not used by this implementation.
        """
        features = self.alloc.allocated[features_n]
        filters = self.alloc.allocated[filters_n]
        out_bp = self.alloc.allocated[out_bp_n]
        dfilters = self.alloc.allocated[dfilters_n]
        # dfilters is reshaped with the dimensions of the already reshaped
        # filters, so both share one layout.
        # NOTE(review): the in-place writes below only reach the stored
        # dfilters tensor if reshape returns a view -- presumably dfilters
        # is contiguous; confirm.
        if not all_weight_is_krsc:
            filters = filters.reshape(-1, *filters.shape[-2:])
            dfilters = dfilters.reshape(-1, *filters.shape[-2:])

        else:
            filters = filters.reshape(filters.shape[0], -1, filters.shape[-1])
            dfilters = dfilters.reshape(filters.shape[0], -1, filters.shape[-1])

        if not all_weight_is_krsc:
            if not is_kc_not_ck:
                torch.mm(features.T, out_bp, out=dfilters[kv_center])
                din = torch.mm(out_bp, filters[kv_center].T)
            else:
                torch.mm(out_bp.T, features, out=dfilters[kv_center])
                din = torch.mm(out_bp, filters[kv_center])
        else:
            # KN @ NC
            torch.mm(out_bp.T, features, out=dfilters[:, kv_center])
            # NK @ KC
            din = torch.mm(out_bp, filters[:, kv_center])
        self.alloc.allocated[AllocKeys.DIn] = din
        return torch_tensor_to_tv(din)

    def indice_conv_bwd_cpu_gemm(self, inp_buffer_n: str, 
                             out_buffer_n: str, filters_n: str, dfilters_n: str,all_weight_is_krsc: bool,
                             is_kc_not_ck: bool, nhot: int, index: int):
        """Backward step for filter slice ``index`` on the first ``nhot`` rows.

        First writes a product of the two buffers into slice ``index`` of
        ``dfilters``, then overwrites ``inp_buffer[:nhot]`` with
        ``out_buffer[:nhot]`` times the filter slice. The order matters:
        the first product reads ``inp_buffer`` before it is overwritten.
        Nothing is returned.

        Args:
            inp_buffer_n: allocator key of the input buffer (read, then
                overwritten).
            out_buffer_n: allocator key of the output buffer (read only).
            filters_n: allocator key of the filters.
            dfilters_n: allocator key of the tensor receiving the filter
                slice product.
            all_weight_is_krsc: selects the filter layout (see class doc).
            is_kc_not_ck: selects operand order for the first product and
                whether the filter slice is transposed for the second.
            nhot: number of leading rows of both buffers to process.
            index: kernel-offset index of the filter slice.
        """
        # Axis that holds the kernel offsets in the reshaped filters.
        kv_dim = 1 if all_weight_is_krsc else 0
        inp_buffer = self.alloc.allocated[inp_buffer_n]
        out_buffer = self.alloc.allocated[out_buffer_n]
        filters = self.alloc.allocated[filters_n]
        dfilters = self.alloc.allocated[dfilters_n]
        if not all_weight_is_krsc:
            filters = filters.reshape(-1, *filters.shape[-2:])
            dfilters = dfilters.reshape(-1, *filters.shape[-2:])

        else:
            filters = filters.reshape(filters.shape[0], -1, filters.shape[-1])
            dfilters = dfilters.reshape(filters.shape[0], -1, filters.shape[-1])

        filters_i = filters.select(kv_dim, index)
        # Slice of dfilters that the first mm below writes into via ``out=``.
        dfilters_i = dfilters.select(kv_dim, index)

        filters_KC = filters_i if is_kc_not_ck else filters_i.T
        if is_kc_not_ck:
            # KN @ NC
            torch.mm(out_buffer[:nhot].T, inp_buffer[:nhot], out=dfilters_i)
        else:
            # CN @ NK
            torch.mm(inp_buffer[:nhot].T, out_buffer[:nhot], out=dfilters_i)
        # NK @ KC
        torch.mm(out_buffer[:nhot], filters_KC, out=inp_buffer[:nhot])

yan.yan's avatar
yan.yan committed
323
324
325
if __name__ == "__main__":
    # Smoke test: wrap a small random tensor and print its numpy view.
    sample = torch.rand(2, 2)
    print(torch_tensor_to_tv(sample).numpy_view())