"wrappers/python/src/vscode:/vscode.git/clone" did not exist on "5fa6fbc194344f34ba444dd4bdf4d7ab7f89db29"
Commit 7af751dc authored by yan.yan's avatar yan.yan
Browse files

sync

parent 647927ce
# Changelog # Changelog
## [2.1.22] - 2022-4-14
### Added
- add full nvrtc support
- add support for large spatial shape and batch size. if detect large shape, we use int64 instead of int32 when hashing.
## [2.1.21] - 2021-12-9 ## [2.1.21] - 2021-12-9
### Added ### Added
- add sm_37 - add sm_37
......
...@@ -56,7 +56,7 @@ def main(): ...@@ -56,7 +56,7 @@ def main():
is_empty = table.insert_exist_keys(keys, values) is_empty = table.insert_exist_keys(keys, values)
ks, vs, cnt = table.items() ks, vs, cnt = table.items()
cnt_item = cnt.item() cnt_item = cnt.item()
print(cnt, ks[:cnt_item], vs[:cnt_item]) print(cnt, ks[:cnt_item], vs[:cnt_item], is_empty.dtype)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -12,28 +12,50 @@ ...@@ -12,28 +12,50 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from enum import Enum import contextlib
from cumm import tensorview as tv
from typing import Dict, List, Set, Tuple, Union
from spconv.core_cc.cumm.gemm.main import GemmAlgoDesp, GemmMainUnitTest, GemmParams
from spconv.core_cc.cumm.conv.main import ConvAlgoDesp, ConvMainUnitTest, ConvParams
from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType
from cumm.gemm.algospec.core import GemmAlgo, ShuffleStrideType, get_min_arch_of_algo_str, get_available_algo_str_from_arch
from cumm.gemm.codeops import group_by, div_up
from spconv.constants import NDIM_DONT_CARE, SPCONV_BWD_SPLITK
from typing import Optional
import time import time
from enum import Enum
from threading import Lock from threading import Lock
import contextlib from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
from spconv.core import ConvAlgo, AlgoHint from cumm import tensorview as tv
from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType
from cumm.conv.kernel import ConvKernel
from cumm.gemm.kernel import GemmKernel
from cumm.gemm.algospec.core import (GemmAlgo, ShuffleStrideType,
get_available_algo_str_from_arch,
get_min_arch_of_algo_str)
from cumm.gemm.codeops import div_up, group_by
from cumm.nvrtc import CummNVRTCModule, get_cudadevrt_path
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview.gemm import ConvOpType as ConvOpTypeCpp
from cumm.tensorview.gemm import ConvParams, GemmAlgoDesp, GemmParams
from cumm import dtypes
from spconv.constants import (NDIM_DONT_CARE, SPCONV_BWD_SPLITK,
SPCONV_NVRTC_MODE, SPCONV_DEBUG_NVRTC_KERNELS)
from spconv.core import ALL_IMPGEMM_PARAMS, AlgoHint, ConvAlgo
from spconv.core_cc.cumm.conv.main import ConvMainUnitTest
from spconv.core_cc.cumm.gemm.main import GemmMainUnitTest
from spconv.cppconstants import COMPILED_CUDA_ARCHS
from cumm.tensorview.gemm import NVRTCParams
from spconv.tools import CUDAKernelTimer from spconv.tools import CUDAKernelTimer
from cumm.gemm.constants import NVRTCConstants, NVRTCMode
from spconv import algocore
from cumm.conv.main import gen_gemm_kernels as gen_conv_kernels
from cumm.gemm.main import gen_gemm_kernels
ALL_ALGO_DESPS = GemmMainUnitTest.get_all_algo_desp() ALL_ALGO_DESPS = GemmMainUnitTest.get_all_algo_desp()
ALL_CONV_ALGO_DESPS = ConvMainUnitTest.get_all_conv_algo_desp() ALL_CONV_ALGO_DESPS = ConvMainUnitTest.get_all_conv_algo_desp()
_GEMM_STATIC_KEY = Tuple[bool, bool, bool, int, int, int, str, str] _GEMM_STATIC_KEY = Tuple[bool, bool, bool, int, int, int, str, str]
class SimpleGemmAlgoMeta: class SimpleGemmAlgoMeta:
def __init__(self, tile_ms: List[int], tile_ns: List[int], def __init__(self, tile_ms: List[int], tile_ns: List[int],
tile_ks: List[int], tile_ks: List[int],
...@@ -45,22 +67,68 @@ class SimpleGemmAlgoMeta: ...@@ -45,22 +67,68 @@ class SimpleGemmAlgoMeta:
class BestAlgoByProfile: class BestAlgoByProfile:
def __init__(self, algo_desp: GemmAlgoDesp, splitk: int = 1) -> None: def __init__(self, algo_desp: GemmAlgoDesp, arch: Tuple[int, int], splitk: int = 1) -> None:
self.algo_desp = algo_desp self.algo_desp = algo_desp
self.splitk = splitk self.splitk = splitk
self.arch = arch
class BestConvAlgoByProfile: class BestConvAlgoByProfile:
def __init__(self, algo_desp: ConvAlgoDesp, splitk: int = 1) -> None: def __init__(self, algo_desp: ConvAlgoDesp, arch: Tuple[int, int], splitk: int = 1) -> None:
self.algo_desp = algo_desp self.algo_desp = algo_desp
self.splitk = splitk self.splitk = splitk
self.arch = arch
def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel], kernel_name: str):
nvrtc_mode = SPCONV_NVRTC_MODE
nvrtc_params = tv.gemm.NVRTCParams()
nvrtc_params.cumodule = mod.get_cpp_object()
nvrtc_params.mode = nvrtc_mode.value
nvrtc_params.num_threads = ker.num_threads
nvrtc_params.smem_size = ker.smem_size
ns = ker.namespace
if nvrtc_mode == NVRTCMode.DynamicParallism:
nvrtc_params.kernel_name = mod.get_lowered_name(
f"{ns}::nvrtc_kernel")
elif nvrtc_mode == NVRTCMode.KernelAndCPU:
nvrtc_params.kernel_name = mod.get_lowered_name(f"{ns}::{kernel_name}")
nvrtc_params.init_kernel_name = mod.get_lowered_name(
f"{ns}::nvrtc_kernel_cpu_out")
nvrtc_params.param_size = mod.const_values[
f"{ns}::{NVRTCConstants.SIZEOF_KEY}"]
nvrtc_params.param_storage = tv.empty([nvrtc_params.param_size],
tv.uint8, 0)
nvrtc_params.param_storage_cpu = tv.empty(
[nvrtc_params.param_size], tv.uint8, -1, pinned=True)
elif nvrtc_mode == NVRTCMode.Direct:
nvrtc_params.kernel_name = mod.get_lowered_name(f"{ns}::{kernel_name}")
elif nvrtc_mode == NVRTCMode.ConstantMemory:
nvrtc_params.kernel_name = mod.get_lowered_name(f"{ns}::{kernel_name}")
nvrtc_params.init_kernel_name = mod.get_lowered_name(
f"{ns}::nvrtc_kernel_cpu_out")
nvrtc_params.param_size = mod.const_values[
f"{ns}::{NVRTCConstants.SIZEOF_KEY}"]
nvrtc_params.constant_name = mod.get_lowered_name(
f"&{ns}::{NVRTCConstants.CONSTANT_PARAM_KEY}")
nvrtc_params.param_storage = tv.empty([nvrtc_params.param_size],
tv.uint8, 0)
else:
raise NotImplementedError
return nvrtc_params
class SimpleGemm: class SimpleGemm:
def __init__(self, desps: List[GemmAlgoDesp]) -> None: def __init__(self, prebuilt_desps: List[GemmAlgoDesp]) -> None:
self.desps = desps all_desps = [algocore.get_conv_algo_desp_from_param(p) for p in ALL_IMPGEMM_PARAMS]
self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
if SPCONV_DEBUG_NVRTC_KERNELS:
self.prebuilt_desp_names.clear()
self.lock = Lock() self.lock = Lock()
self.static_key_to_desps = group_by(self.get_static_key, desps) self.static_key_to_desps = group_by(self.get_static_key, all_desps)
self.static_key_to_meta: Dict[_GEMM_STATIC_KEY, self.static_key_to_meta: Dict[_GEMM_STATIC_KEY,
SimpleGemmAlgoMeta] = {} SimpleGemmAlgoMeta] = {}
for k, static_desps in self.static_key_to_desps.items(): for k, static_desps in self.static_key_to_desps.items():
...@@ -94,15 +162,44 @@ class SimpleGemm: ...@@ -94,15 +162,44 @@ class SimpleGemm:
self.mn_cache: Dict[Tuple[int, int, int, int, int], self.mn_cache: Dict[Tuple[int, int, int, int, int],
BestAlgoByProfile] = {} # for backward weight BestAlgoByProfile] = {} # for backward weight
self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int]], NVRTCParams] = {}
@staticmethod @staticmethod
def get_static_key(d: GemmAlgoDesp) -> _GEMM_STATIC_KEY: def get_static_key(d: GemmAlgoDesp) -> _GEMM_STATIC_KEY:
return (d.trans_a, d.trans_b, d.trans_c, d.dtype_a, d.dtype_b, return (d.trans_a, d.trans_b, d.trans_c, d.dtype_a, d.dtype_b,
d.dtype_c, d.shuffle_type, d.algo) d.dtype_c, d.shuffle_type.value, d.algo)
def device_synchronize(self): def device_synchronize(self):
return GemmMainUnitTest.device_synchronize() return GemmMainUnitTest.device_synchronize()
def _compile_nvrtc_module(self, desp: GemmAlgoDesp):
params = algocore.get_gemm_param_from_desp(desp)
kernel = gen_gemm_kernels(params, SPCONV_NVRTC_MODE)
kernel.namespace = "spconv"
custom_names = []
if SPCONV_NVRTC_MODE == NVRTCMode.ConstantMemory:
custom_names = [f"&{kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"]
cudadevrt = ""
if SPCONV_NVRTC_MODE == NVRTCMode.DynamicParallism:
cudadevrt_p = get_cudadevrt_path()
assert cudadevrt_p is not None, "DynamicParallism must have cudadevrt"
cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt,
verbose=False,
custom_names=custom_names)
mod.load()
return mod, kernel
def _cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int, int]):
key = (str(desp), arch)
if key in self._nvrtc_caches:
return self._nvrtc_caches[key]
mod, ker = self._compile_nvrtc_module(desp)
nvrtc_params = _get_nvrtc_params(mod, ker, "gemm_kernel")
self._nvrtc_caches[key] = nvrtc_params
return nvrtc_params
def get_all_available( def get_all_available(
self, self,
a: tv.Tensor, a: tv.Tensor,
...@@ -135,6 +232,11 @@ class SimpleGemm: ...@@ -135,6 +232,11 @@ class SimpleGemm:
ldb = b.dim(1) ldb = b.dim(1)
ldc = c.dim(1) ldc = c.dim(1)
if desp.supported_ldx(lda, ldb, ldc): if desp.supported_ldx(lda, ldb, ldc):
if arch not in COMPILED_CUDA_ARCHS:
desp = desp.copy()
desp.is_nvrtc = True
if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True
finally_algos.append(desp) finally_algos.append(desp)
return finally_algos return finally_algos
...@@ -334,6 +436,8 @@ class SimpleGemm: ...@@ -334,6 +436,8 @@ class SimpleGemm:
if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value: if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
split_k_slices = max(min(32, k // 128), 1) split_k_slices = max(min(32, k // 128), 1)
params = GemmParams() params = GemmParams()
if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(desp, arch)
params.a = a params.a = a
params.b = b params.b = b
params.c = c_ params.c = c_
...@@ -361,7 +465,7 @@ class SimpleGemm: ...@@ -361,7 +465,7 @@ class SimpleGemm:
times.append(np.mean(this_times[1:])) times.append(np.mean(this_times[1:]))
spk_speeds.append(times[-1]) spk_speeds.append(times[-1])
all_profile_res.append(BestAlgoByProfile(desp, splitk=spk)) all_profile_res.append(BestAlgoByProfile(desp, arch, splitk=spk))
min_time = 1000 min_time = 1000
min_idx = -1 min_idx = -1
...@@ -421,6 +525,9 @@ class SimpleGemm: ...@@ -421,6 +525,9 @@ class SimpleGemm:
if profile_res.splitk > 1: if profile_res.splitk > 1:
split_k_slices = profile_res.splitk split_k_slices = profile_res.splitk
params = GemmParams() params = GemmParams()
if algo_desp.is_nvrtc and str(algo_desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(algo_desp, profile_res.arch)
params.a = a params.a = a
params.b = b params.b = b
params.c = c params.c = c
...@@ -461,11 +568,14 @@ _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int] ...@@ -461,11 +568,14 @@ _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int]
class SimpleConv: class SimpleConv:
def __init__(self, desps: List[ConvAlgoDesp]) -> None: def __init__(self, prebuilt_desps: List[ConvAlgoDesp]) -> None:
self.desps = desps all_desps = [algocore.get_conv_algo_desp_from_param(p) for p in ALL_IMPGEMM_PARAMS]
self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
self.prebuilt_desp_names.clear()
self.lock = Lock() self.lock = Lock()
self.static_key_to_desps = group_by(self.get_static_key, desps) self.static_key_to_desps = group_by(self.get_static_key, all_desps)
self.static_key_to_meta: Dict[_CONV_STATIC_KEY, self.static_key_to_meta: Dict[_CONV_STATIC_KEY,
SimpleGemmAlgoMeta] = {} SimpleGemmAlgoMeta] = {}
for k, static_desps in self.static_key_to_desps.items(): for k, static_desps in self.static_key_to_desps.items():
...@@ -500,28 +610,36 @@ class SimpleConv: ...@@ -500,28 +610,36 @@ class SimpleConv:
self.kc_wgrad_cache: Dict[Tuple[int, int, int, int, int, int, int, self.kc_wgrad_cache: Dict[Tuple[int, int, int, int, int, int, int,
int], BestConvAlgoByProfile] = { int], BestConvAlgoByProfile] = {
} # for backward weight } # for backward weight
self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int]], NVRTCParams] = {}
@staticmethod @staticmethod
def get_static_key(d: ConvAlgoDesp) -> _CONV_STATIC_KEY: def get_static_key(d: ConvAlgoDesp) -> _CONV_STATIC_KEY:
return (d.layout_i, d.layout_w, d.layout_o, d.interleave_i, return (d.layout_i.value, d.layout_w.value, d.layout_o.value,
d.interleave_w, d.interleave_o, d.dtype_input, d.dtype_weight, d.interleave_i, d.interleave_w, d.interleave_o, d.dtype_input,
d.dtype_output, d.algo, d.op_type) d.dtype_weight, d.dtype_output, d.algo, d.op_type.value)
def device_synchronize(self): def device_synchronize(self):
return GemmMainUnitTest.device_synchronize() return GemmMainUnitTest.device_synchronize()
def get_all_available(self, inp: tv.Tensor, weight: tv.Tensor, def get_all_available(self,
out: tv.Tensor, layout_i: ConvLayout, inp: tv.Tensor,
layout_w: ConvLayout, layout_o: ConvLayout, weight: tv.Tensor,
arch: Tuple[int, int], op_type: ConvOpType, out: tv.Tensor,
mask_width: int, fp32_accum: Optional[bool] = None): layout_i: ConvLayout,
layout_w: ConvLayout,
layout_o: ConvLayout,
arch: Tuple[int, int],
op_type: ConvOpType,
mask_width: int,
fp32_accum: Optional[bool] = None):
avail_algos = get_available_algo_str_from_arch(arch) avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[ConvAlgoDesp] = [] finally_algos: List[ConvAlgoDesp] = []
is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 and out.dtype == tv.float16 is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 and out.dtype == tv.float16
use_f32_as_accum = False use_f32_as_accum = False
kv = int(np.prod(weight.shape[1:-1])) kv = int(np.prod(weight.shape[1:-1]))
# for 3d conv, if reduce axis is too large, may cause nan during # for 3d conv, if reduce axis is too large, may cause nan during
# forward. # forward.
if is_fp16: if is_fp16:
if fp32_accum is None: if fp32_accum is None:
...@@ -551,7 +669,7 @@ class SimpleConv: ...@@ -551,7 +669,7 @@ class SimpleConv:
if use_f32_as_accum: if use_f32_as_accum:
if desp.dacc == tv.float16: if desp.dacc == tv.float16:
continue continue
ldi = inp.dim(-1) ldi = inp.dim(-1)
ldw = weight.dim(-1) ldw = weight.dim(-1)
ldo = out.dim(-1) ldo = out.dim(-1)
...@@ -560,6 +678,11 @@ class SimpleConv: ...@@ -560,6 +678,11 @@ class SimpleConv:
assert mask_width > 0 assert mask_width > 0
mask_width_valid = mask_width % desp.tile_shape[2] == 0 mask_width_valid = mask_width % desp.tile_shape[2] == 0
if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid: if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
if arch not in COMPILED_CUDA_ARCHS:
desp = desp.copy()
desp.is_nvrtc = True
if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True
finally_algos.append(desp) finally_algos.append(desp)
return finally_algos return finally_algos
...@@ -592,6 +715,34 @@ class SimpleConv: ...@@ -592,6 +715,34 @@ class SimpleConv:
return desp.query_conv_workspace_size(mnk[0], mnk[1], mnk[2], splitk, return desp.query_conv_workspace_size(mnk[0], mnk[1], mnk[2], splitk,
kv) kv)
def _compile_nvrtc_module(self, desp: ConvAlgoDesp):
params = algocore.get_conv_param_from_desp(desp)
kernel = gen_conv_kernels(params, SPCONV_NVRTC_MODE)
kernel.namespace = "spconv"
custom_names = []
if SPCONV_NVRTC_MODE == NVRTCMode.ConstantMemory:
custom_names = [f"&{kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"]
cudadevrt = ""
if SPCONV_NVRTC_MODE == NVRTCMode.DynamicParallism:
cudadevrt_p = get_cudadevrt_path()
assert cudadevrt_p is not None, "DynamicParallism must have cudadevrt"
cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt,
verbose=False,
custom_names=custom_names)
mod.load()
return mod, kernel
def _cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int]):
key = (str(desp), arch)
if key in self._nvrtc_caches:
return self._nvrtc_caches[key]
mod, ker = self._compile_nvrtc_module(desp)
nvrtc_params = _get_nvrtc_params(mod, ker, "conv_kernel")
self._nvrtc_caches[key] = nvrtc_params
return nvrtc_params
def tune_and_cache(self, def tune_and_cache(self,
op_type: ConvOpType, op_type: ConvOpType,
inp: tv.Tensor, inp: tv.Tensor,
...@@ -613,7 +764,7 @@ class SimpleConv: ...@@ -613,7 +764,7 @@ class SimpleConv:
stream: int = 0, stream: int = 0,
fp32_accum: Optional[bool] = None): fp32_accum: Optional[bool] = None):
avail = self.get_all_available(inp, weight, output, layout_i, layout_w, avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, arch, op_type, mask_width, layout_o, arch, op_type, mask_width,
fp32_accum) fp32_accum)
inp = inp.clone() inp = inp.clone()
weight = weight.clone() weight = weight.clone()
...@@ -626,7 +777,10 @@ class SimpleConv: ...@@ -626,7 +777,10 @@ class SimpleConv:
all_profile_res: List[BestConvAlgoByProfile] = [] all_profile_res: List[BestConvAlgoByProfile] = []
for desp in avail: for desp in avail:
# for sparse conv, ndim isn't used, so we just provide a constant value. # for sparse conv, ndim isn't used, so we just provide a constant value.
params = ConvParams(NDIM_DONT_CARE, op_type.value) params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type.value))
if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(desp, arch)
params.conv_algo_desp = desp params.conv_algo_desp = desp
params.input = inp params.input = inp
params.weight = weight.view([channel_k, -1, channel_c]) params.weight = weight.view([channel_k, -1, channel_c])
...@@ -657,13 +811,16 @@ class SimpleConv: ...@@ -657,13 +811,16 @@ class SimpleConv:
GemmMainUnitTest.stream_synchronize(stream) GemmMainUnitTest.stream_synchronize(stream)
t = time.time() t = time.time()
params.split_k_slices = spk params.split_k_slices = spk
ConvMainUnitTest.implicit_gemm2(params) if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names:
tv.gemm.run_nvrtc_conv_kernel(params)
else:
ConvMainUnitTest.implicit_gemm2(params)
GemmMainUnitTest.stream_synchronize(stream) GemmMainUnitTest.stream_synchronize(stream)
this_times.append(time.time() - t) this_times.append(time.time() - t)
times.append(np.mean(this_times[1:])) times.append(np.mean(this_times[1:]))
spk_speeds.append(times[-1]) spk_speeds.append(times[-1])
all_profile_res.append(BestConvAlgoByProfile(desp, splitk=spk)) all_profile_res.append(BestConvAlgoByProfile(desp, arch, splitk=spk))
if not all_profile_res: if not all_profile_res:
raise ValueError("can't find suitable algorithm for", op_type) raise ValueError("can't find suitable algorithm for", op_type)
min_time = 1000 min_time = 1000
...@@ -720,7 +877,9 @@ class SimpleConv: ...@@ -720,7 +877,9 @@ class SimpleConv:
op_type_value = op_type op_type_value = op_type
else: else:
op_type_value = op_type.value op_type_value = op_type.value
params = ConvParams(NDIM_DONT_CARE, op_type_value) params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type_value))
if algo_desp.is_nvrtc and str(algo_desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(algo_desp, profile_res.arch)
params.conv_algo_desp = profile_res.algo_desp params.conv_algo_desp = profile_res.algo_desp
params.input = inp params.input = inp
params.verbose = verbose params.verbose = verbose
......
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Set, Tuple, Union
from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType
from cumm.gemm.algospec.core import (GemmAlgo, ShuffleStrideType)
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview.gemm import ConvIterAlgo as ConvIterAlgoCpp
from cumm.tensorview.gemm import ConvOpType as ConvOpTypeCpp
from cumm.tensorview.gemm import ConvLayoutType as ConvLayoutTypeCpp
from cumm.tensorview.gemm import ShuffleStrideType as ShuffleStrideTypeCpp
from cumm.tensorview.gemm import ConvParams, GemmAlgoDesp, GemmParams
from cumm.gemm.main import GemmAlgoParams
from cumm.conv.main import ConvAlgoParams, ConvIterAlgo
from cumm import dtypes
from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout,
ConvLayoutType, ConvMode, ConvOpType)
from cumm.gemm.core import MetaArray
from cumm.gemm.algospec import TensorOp
def _assign_gemm_desp_props(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
p: Union[GemmAlgoParams, ConvAlgoParams]):
desp.dtype_a = p.dtype_a.tv_dtype
desp.dtype_b = p.dtype_a.tv_dtype
desp.dtype_c = p.dtype_a.tv_dtype
desp.dacc = p.dtype_acc.tv_dtype
desp.dcomp = p.dtype_comp.tv_dtype
desp.trans_a = p.trans_a
desp.trans_b = p.trans_b
desp.trans_c = p.trans_c
desp.tile_shape = (p.ts[0], p.ts[1], p.ts[2])
desp.warp_tile_shape = (p.wts[0], p.wts[1], p.wts[2])
if p.tensorop is not None:
desp.tensorop = (p.tensorop[0], p.tensorop[1], p.tensorop[2])
desp.num_stage = p.num_stage
desp.algo = p.algo.value
desp.split_k_serial = p.splitk_serial
desp.split_k_parallel = p.splitk_parallel
desp.shuffle_type = ShuffleStrideTypeCpp(p.shuffle_stride.value)
desp.access_per_vector = p.access_per_vector
desp.is_nvrtc = p.is_nvrtc
def get_gemm_algo_desp_from_param(p: GemmAlgoParams):
desp = GemmAlgoDesp()
_assign_gemm_desp_props(desp, p)
return desp
def get_conv_algo_desp_from_param(p: ConvAlgoParams):
desp = ConvAlgoDesp(p.ndim, ConvOpTypeCpp(p.op_type.value))
_assign_gemm_desp_props(desp, p)
# conv attrs
desp.ndim = p.ndim
desp.op_type = ConvOpTypeCpp(p.op_type.value)
desp.iter_algo = ConvIterAlgoCpp(p.iter_algo.value)
desp.layout_i = ConvLayoutTypeCpp(p.layout_desp_input.layout_type.value)
desp.layout_w = ConvLayoutTypeCpp(p.layout_desp_weight.layout_type.value)
desp.layout_o = ConvLayoutTypeCpp(p.layout_desp_output.layout_type.value)
desp.interleave_i = p.layout_desp_input.interleave
desp.interleave_w = p.layout_desp_weight.interleave
desp.interleave_o = p.layout_desp_output.interleave
desp.mask_sparse = p.mask_sparse
desp.increment_k_first = p.increment_k_first
return desp
def _assign_gemm_params(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
p: Union[GemmAlgoParams, ConvAlgoParams]):
p.dtype_a = dtypes.get_dtype_from_tvdtype(desp.dtype_a)
p.dtype_b = dtypes.get_dtype_from_tvdtype(desp.dtype_b)
p.dtype_c = dtypes.get_dtype_from_tvdtype(desp.dtype_c)
p.dtype_acc = dtypes.get_dtype_from_tvdtype(desp.dacc)
p.dtype_comp = dtypes.get_dtype_from_tvdtype(desp.dcomp)
p.trans_a = desp.trans_a
p.trans_b = desp.trans_b
p.trans_c = desp.trans_c
p.ts = MetaArray(*desp.tile_shape)
p.wts = MetaArray(*desp.warp_tile_shape)
if desp.tensorop[0] > 0:
p.tensorop = TensorOp(
(desp.tensorop[0], desp.tensorop[1], desp.tensorop[2]))
p.num_stage = desp.num_stage
p.algo = GemmAlgo(desp.algo)
p.splitk_serial = desp.split_k_serial
p.splitk_parallel = desp.split_k_parallel
p.shuffle_stride = ShuffleStrideType(desp.shuffle_type.value)
p.access_per_vector = desp.access_per_vector
p.is_nvrtc = desp.is_nvrtc
def get_gemm_param_from_desp(desp: GemmAlgoDesp):
p = GemmAlgoParams((0, 0, 0), (0, 0, 0), 0, "s8,s8,s8,s8,s8", False, False,
False, GemmAlgo.Simt)
_assign_gemm_params(desp, p)
return p
def get_conv_param_from_desp(desp: ConvAlgoDesp):
p = ConvAlgoParams(desp.ndim, ConvOpType.kForward, ConvIterAlgo.Optimized,
(0, 0, 0), (0, 0, 0), 0, "s8,s8,s8,s8,s8", NHWC, NHWC, NHWC,
GemmAlgo.Simt)
_assign_gemm_params(desp, p)
# conv attrs
p.ndim = desp.ndim
p.op_type = ConvOpType(desp.op_type.value)
p.iter_algo = ConvIterAlgo(desp.iter_algo.value)
p.layout_desp_input = ConvLayout(ConvLayoutType(desp.layout_i.value),
desp.interleave_i)
p.layout_desp_weight = ConvLayout(ConvLayoutType(desp.layout_w.value),
desp.interleave_w)
p.layout_desp_output = ConvLayout(ConvLayoutType(desp.layout_o.value),
desp.interleave_o)
p.mask_sparse = desp.mask_sparse
p.increment_k_first = desp.increment_k_first
return p
...@@ -29,21 +29,20 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable( ...@@ -29,21 +29,20 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from cumm.common import CompileInfo from cumm.common import CompileInfo
from spconv.csrc.sparse.all import SpconvOps from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.sparse.alloc import ExternalAllocator
from spconv.csrc.utils import BoxOps from spconv.csrc.utils import BoxOps
from spconv.csrc.hash.core import HashTable from spconv.csrc.hash.core import HashTable
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
SHUFFLE_TURING_PARAMS) cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS) IMPLGEMM_TURING_PARAMS)
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main" convcu.namespace = "cumm.conv.main"
objects_folder = None pccm.builder.build_pybind([cu, convcu, SpconvOps(), BoxOps(), HashTable(), CompileInfo(), ExternalAllocator()],
if InWindows:
# windows have command line limit, so we use objects_folder to reduce command size.
objects_folder = "objects"
pccm.builder.build_pybind([cu, convcu, SpconvOps(), BoxOps(), HashTable(), CompileInfo()],
PACKAGE_ROOT / "core_cc", PACKAGE_ROOT / "core_cc",
namespace_root=PACKAGE_ROOT, namespace_root=PACKAGE_ROOT,
objects_folder=objects_folder,
load_library=False) load_library=False)
...@@ -16,6 +16,8 @@ import os ...@@ -16,6 +16,8 @@ import os
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from pccm.utils import project_is_editable, project_is_installed from pccm.utils import project_is_editable, project_is_installed
from cumm.gemm.constants import NVRTCMode
import enum
PACKAGE_NAME = "spconv" PACKAGE_NAME = "spconv"
PACKAGE_ROOT = Path(__file__).parent.resolve() PACKAGE_ROOT = Path(__file__).parent.resolve()
...@@ -43,3 +45,21 @@ else: ...@@ -43,3 +45,21 @@ else:
# for f16 backward weight, larger splitk, larger compute error. # for f16 backward weight, larger splitk, larger compute error.
# so we use this env to control maximum splitk. # so we use this env to control maximum splitk.
SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,64").split(","))) SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,64").split(",")))
SPCONV_NVRTC_MODE = NVRTCMode.ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS = False
class SpconvAllocatorKeys:
Pair = "Pair"
IndiceNumPerLoc = "IndiceNumPerLoc"
PairMask = "PairMask"
MaskArgSort = "MaskArgSort"
OutIndices = "OutIndices"
PairFwd = "PairFwd"
# PairMaskFwd = "PairMaskFwd"
PairMaskBwd = "PairMaskBwd"
# MaskArgSortFwd = "MaskArgSortFwd"
MaskArgSortBwd = "MaskArgSortBwd"
OutFeatures = "OutFeatures"
...@@ -15,17 +15,17 @@ from enum import Enum ...@@ -15,17 +15,17 @@ from enum import Enum
from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgoParams from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgoParams
from cumm.gemm import kernel from cumm.gemm import kernel
from typing import List from typing import List
from cumm.gemm.algospec.core import TensorOpParams from cumm.gemm.algospec.core import TensorOp
from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvIterAlgo, GemmAlgo from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvIterAlgo, GemmAlgo
from cumm.conv.bases import (NCHW, NHWC, ConvEnum, ConvIterAlgo, ConvLayout, from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout,
ConvLayoutType, ConvMode, ConvOpType) ConvLayoutType, ConvMode, ConvOpType)
from spconv.constants import NDIM_DONT_CARE from spconv.constants import NDIM_DONT_CARE
class ConvAlgo(Enum): class ConvAlgo(Enum):
Native = "Native" Native = 0
MaskImplicitGemm = "MaskImplicitGemm" MaskImplicitGemm = 1
MaskSplitImplicitGemm = "MaskSplitImplicitGemm" MaskSplitImplicitGemm = 2
class AlgoHint(Enum): class AlgoHint(Enum):
...@@ -40,17 +40,17 @@ class AlgoHint(Enum): ...@@ -40,17 +40,17 @@ class AlgoHint(Enum):
SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), 2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), 2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
*gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], *gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"],
"", 2, kernel.GemmAlgo.SimtDP4A, None), "", 2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2, (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
kernel.GemmAlgo.SimtDP4A, None), kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), 2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
*gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"], *gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
# *gen_shuffle_params( # *gen_shuffle_params(
...@@ -104,88 +104,88 @@ SHUFFLE_VOLTA_PARAMS: List[GemmAlgoParams] = [ ...@@ -104,88 +104,88 @@ SHUFFLE_VOLTA_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params( *gen_shuffle_params(
(64, 64, 32), (64, 64, 32),
(32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
# *gen_shuffle_params( # *gen_shuffle_params(
# (128, 128, 32), # (128, 128, 32),
# (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, # (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), # kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 256, 32), (128, 256, 32),
(64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
*gen_shuffle_params( *gen_shuffle_params(
(256, 128, 32), (256, 128, 32),
(64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 64, 32), (128, 64, 32),
(64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
*gen_shuffle_params( *gen_shuffle_params(
(64, 128, 32), (64, 128, 32),
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
] ]
# SHUFFLE_VOLTA_PARAMS = [] # SHUFFLE_VOLTA_PARAMS = []
SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params( *gen_shuffle_params(
(64, 64, 32), (64, 64, 32),
(32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
# *gen_shuffle_params( # *gen_shuffle_params(
# (128, 128, 32), # (128, 128, 32),
# (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, # (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), # kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(64, 64, 64), (64, 64, 64),
(32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(64, 128, 64), (64, 128, 64),
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 256, 32), (128, 256, 32),
(64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(256, 128, 32), (256, 128, 32),
(64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 64, 32), (128, 64, 32),
(64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(64, 128, 32), (64, 128, 32),
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOpParams((8, 8, 16))), TensorOp((8, 8, 16)), is_nvrtc=True),
# *gen_shuffle_params( # *gen_shuffle_params(
# (128, 128, 32), # (128, 128, 32),
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2, # (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
# kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), # kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 256, 32), (128, 256, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOpParams((8, 8, 16))), TensorOp((8, 8, 16)), is_nvrtc=True),
*gen_shuffle_params( *gen_shuffle_params(
(256, 128, 32), (256, 128, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOpParams((8, 8, 16))), TensorOp((8, 8, 16)), is_nvrtc=True),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True),
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True),
] ]
# SHUFFLE_TURING_PARAMS = [] # SHUFFLE_TURING_PARAMS = []
...@@ -399,6 +399,34 @@ IMPLGEMM_SIMT_PARAMS = [ ...@@ -399,6 +399,34 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
] ]
IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 32, 16), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Simt,
None,
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 32, 16), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Simt,
None,
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
]
IMPLGEMM_VOLTA_PARAMS = [ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
...@@ -408,7 +436,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -408,7 +436,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -420,7 +448,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -420,7 +448,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=0), access_per_vector=0),
...@@ -432,7 +460,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -432,7 +460,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -444,7 +472,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -444,7 +472,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -456,7 +484,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -456,7 +484,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -468,7 +496,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -468,7 +496,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=0), access_per_vector=0),
...@@ -480,7 +508,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -480,7 +508,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -495,7 +523,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -495,7 +523,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=0), access_per_vector=0),
...@@ -507,7 +535,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -507,7 +535,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -519,7 +547,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -519,7 +547,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -531,7 +559,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -531,7 +559,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -543,7 +571,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -543,7 +571,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -555,7 +583,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -555,7 +583,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -567,7 +595,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -567,7 +595,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -579,7 +607,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -579,7 +607,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -591,7 +619,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -591,7 +619,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -603,7 +631,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -603,7 +631,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -615,7 +643,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -615,7 +643,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -628,7 +656,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -628,7 +656,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=0), access_per_vector=0),
...@@ -641,7 +669,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -641,7 +669,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -654,12 +682,16 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -654,12 +682,16 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
# *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32", # *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32",
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), # NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, ) # gen_conv_params(ConvFwdAndBwdInput, )
] ]
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS
ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS
...@@ -34,13 +34,15 @@ class SpconvOps: ...@@ -34,13 +34,15 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_stage2(indices: Tensor, hashdata: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor, out_inds: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int: def generate_conv_inds_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
hashdata: hashdata_k:
hashdata_v:
indice_pairs: indice_pairs:
indice_pairs_uniq: indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds: out_inds:
num_out_act: num_out_act:
batch_size: batch_size:
...@@ -74,14 +76,16 @@ class SpconvOps: ...@@ -74,14 +76,16 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int: def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
hashdata: hashdata_k:
hashdata_v:
indice_pairs_fwd: indice_pairs_fwd:
indice_pairs_bwd: indice_pairs_bwd:
indice_pairs_uniq: indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds: out_inds:
mask_fwd: mask_fwd:
mask_bwd: mask_bwd:
...@@ -98,11 +102,12 @@ class SpconvOps: ...@@ -98,11 +102,12 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int: def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
hashdata: hashdata_k:
hashdata_v:
indice_pairs: indice_pairs:
out_inds: out_inds:
indice_num_per_loc: indice_num_per_loc:
...@@ -276,6 +281,18 @@ class SpconvOps: ...@@ -276,6 +281,18 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key_split_allocator_v2(data: Tensor, allocator, mask: Tensor, indices: Tensor = Tensor(), stream: int = 0, mask_output: bool = False) -> Tensor:
"""
Args:
data:
allocator:
mask:
indices:
stream:
mask_output:
"""
...
@staticmethod
def count_bits(a: Tensor) -> Tensor: def count_bits(a: Tensor) -> Tensor:
""" """
Args: Args:
...@@ -328,3 +345,51 @@ class SpconvOps: ...@@ -328,3 +345,51 @@ class SpconvOps:
stream_int: stream_int:
""" """
... ...
@staticmethod
def get_int32_max() -> int: ...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0) -> Tensor:
"""
Args:
allocator:
indices:
batch_size:
input_dims:
algo:
ksize:
stride:
padding:
dilation:
out_padding:
subm:
transposed:
is_train:
stream_int:
"""
...
@staticmethod
def get_indice_pairs(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, stream_int: int = 0) -> None:
"""
Args:
allocator:
indices:
batch_size:
input_dims:
algo:
ksize:
stride:
padding:
dilation:
out_padding:
subm:
transposed:
stream_int:
"""
...
@staticmethod
def test_allocator(allocator) -> None:
"""
Args:
allocator:
"""
...
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
"""
...
def empty(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
"""
...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
value:
dtype:
device:
"""
...
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
value:
dtype:
device:
"""
...
def free(self, ten: Tensor) -> None:
"""
Args:
ten:
"""
...
def free_noexcept(self, ten: Tensor) -> None:
"""
Args:
ten:
"""
...
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue from pccm.stubs import EnumValue, EnumClassValue
from ...cumm.gemm.main import GemmAlgoDesp from cumm.tensorview.gemm import ConvParams
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ConvAlgoDesp(GemmAlgoDesp):
ndim: int
op_type: int
iter_algo: int
layout_i: int
layout_w: int
layout_o: int
interleave_i: int
interleave_w: int
interleave_o: int
mask_sparse: bool
increment_k_first: bool
def __init__(self, ndim: int, op_type: int) -> None:
"""
Args:
ndim:
op_type:
"""
...
def __repr__(self) -> str: ...
@staticmethod
def conv_iwo_012_to_abc(op_type: int) -> List[int]:
"""
Args:
op_type:
"""
...
@staticmethod
def gemm_abc_012_to_iwo(op_type: int) -> List[int]:
"""
Args:
op_type:
"""
...
@property
def dtype_input(self) -> int: ...
@property
def dtype_weight(self) -> int: ...
@property
def dtype_output(self) -> int: ...
def supported(self, m: int, n: int, k: int, C: int, K: int, mask_width: int) -> bool:
"""
Args:
m:
n:
k:
C:
K:
mask_width:
"""
...
def query_conv_workspace_size(self, m: int, n: int, k: int, split_k_slices: int, kv: int) -> int:
"""
Args:
m:
n:
k:
split_k_slices:
kv:
"""
...
def supported_ldx_conv(self, ldi: int, ldw: int, ldo: int) -> bool:
"""
Args:
ldi:
ldw:
ldo:
"""
...
class ConvParams:
conv_algo_desp: Any
input: Tensor
weight: Tensor
output: Tensor
split_k_slices: int
padding: List[int]
stride: List[int]
dilation: List[int]
alpha: float
beta: float
mask_width: int
mask_filter: int
reverse_mask: bool
verbose: bool
timer: CUDAKernelTimer
workspace: Tensor = Tensor()
mask: Tensor = Tensor()
mask_argsort: Tensor = Tensor()
indices: Tensor = Tensor()
mask_output: Tensor = Tensor()
stream: int
def __init__(self, ndim: int, op_type: int, timer: CUDAKernelTimer = CUDAKernelTimer(False)) -> None:
"""
Args:
ndim:
op_type:
timer:
"""
...
class ConvMainUnitTest: class ConvMainUnitTest:
@staticmethod @staticmethod
def extract_mnk(op_type: int, N: int, C: int, K: int, kernel_volume: int, in_prod: int, out_prod: int, mask_sparse: bool) -> List[int]: def extract_mnk(op_type: int, N: int, C: int, K: int, kernel_volume: int, in_prod: int, out_prod: int, mask_sparse: bool) -> List[int]:
......
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor from cumm.tensorview.gemm import GemmParams
from cumm.tensorview import CUDAKernelTimer
class GemmAlgoDesp:
dtype_a: int
dtype_b: int
dtype_c: int
tile_shape: Tuple[int, int, int]
warp_tile_shape: Tuple[int, int, int]
num_stage: int
dacc: int
dcomp: int
algo: str
tensorop: List[int]
split_k_serial_: int
split_k_parallel_: int
shuffle_type: str
element_per_access_a: int
element_per_access_b: int
element_per_access_c: int
access_per_vector: int
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
@property
def split_k_serial(self) -> bool: ...
@split_k_serial.setter
def split_k_serial(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def split_k_parallel(self) -> bool: ...
@split_k_parallel.setter
def split_k_parallel(self, val: bool) -> None:
"""
Args:
val:
"""
...
def check_valid(self) -> None: ...
@property
def trans_a(self) -> bool: ...
@trans_a.setter
def trans_a(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def trans_b(self) -> bool: ...
@trans_b.setter
def trans_b(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def trans_c(self) -> bool: ...
@trans_c.setter
def trans_c(self, val: bool) -> None:
"""
Args:
val:
"""
...
def query_workspace_size(self, m: int, n: int, k: int, split_k_slices: int) -> int:
"""
Args:
m:
n:
k:
split_k_slices:
"""
...
def supported(self, m: int, n: int, k: int) -> bool:
"""
Args:
m:
n:
k:
"""
...
def supported_ldx(self, lda: int, ldb: int, ldc: int) -> bool:
"""
Args:
lda:
ldb:
ldc:
"""
...
class GemmParams:
algo_desp: GemmAlgoDesp
split_k_slices: int
workspace: Tensor = Tensor()
a_inds: Tensor = Tensor()
b_inds: Tensor = Tensor()
c_inds: Tensor = Tensor()
alpha: float
beta: float
stream: int
timer: CUDAKernelTimer
def __init__(self, timer: CUDAKernelTimer = CUDAKernelTimer(False)) -> None:
"""
Args:
timer:
"""
...
def check_valid(self) -> None: ...
@property
def a(self) -> Tensor: ...
@a.setter
def a(self, val: Tensor) -> None:
"""
Args:
val:
"""
...
@property
def b(self) -> Tensor: ...
@b.setter
def b(self, val: Tensor) -> None:
"""
Args:
val:
"""
...
@property
def c(self) -> Tensor: ...
@c.setter
def c(self, val: Tensor) -> None:
"""
Args:
val:
"""
...
class GemmMainUnitTest: class GemmMainUnitTest:
@staticmethod @staticmethod
def get_all_algo_desp() -> List[GemmAlgoDesp]: ... def get_all_algo_desp() -> List[Any]: ...
@staticmethod @staticmethod
def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type: str = "NS", a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]: def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type: str = "0", a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]:
""" """
Args: Args:
a_shape: a_shape:
......
...@@ -104,6 +104,8 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -104,6 +104,8 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
self.add_member("map_8_8", "tsl::robin_map<uint64_t, uint64_t>") self.add_member("map_8_8", "tsl::robin_map<uint64_t, uint64_t>")
self.add_pybind_member("insert_count_", "int64_t", prop_name="insert_count", readwrite=False) self.add_pybind_member("insert_count_", "int64_t", prop_name="insert_count", readwrite=False)
self.valid_hash_key_types = [dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64]
@pccm.pybind.mark @pccm.pybind.mark
@pccm.constructor @pccm.constructor
def ctor(self): def ctor(self):
...@@ -163,11 +165,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -163,11 +165,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
""") """)
for v_items in _dispatch_ints(code, [4, 8], "values_data.itemsize()"): for v_items in _dispatch_ints(code, [4, 8], "values_data.itemsize()"):
...@@ -176,10 +176,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -176,10 +176,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data()); V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(table.size(), custream); tv::cuda::Launch launcher(table.size(), custream);
launcher(tv::hash::clear_table_split<table_t>, table); launcher(tv::hash::clear_map_kernel_split<table_t>, table);
""") """)
return code return code
...@@ -201,9 +201,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -201,9 +201,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
int64_t value_after_insert = keys.dim(0) + insert_count_; int64_t value_after_insert = keys.dim(0) + insert_count_;
TV_ASSERT_RT_ERR(value_after_insert < keys_data.dim(0), "inserted count exceed maximum hash size"); TV_ASSERT_RT_ERR(value_after_insert < keys_data.dim(0), "inserted count exceed maximum hash size");
insert_count_ += keys.dim(0); insert_count_ += keys.dim(0);
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}} }}
auto N = keys.dim(0); auto N = keys.dim(0);
TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_);
if (!values.empty()){{ if (!values.empty()){{
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_); TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(keys.dim(0) == values.dim(0), "number of key and value must same"); TV_ASSERT_RT_ERR(keys.dim(0) == values.dim(0), "number of key and value must same");
...@@ -231,10 +231,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -231,10 +231,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data()); const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
...@@ -248,7 +247,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -248,7 +247,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
tv::cuda::Launch launcher(N, custream); tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::insert_split<table_t>, table, key_ptr, value_ptr, size_t(N)); launcher(tv::hash::insert_split<table_t>, table, key_ptr, value_ptr, size_t(N));
...@@ -279,6 +278,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -279,6 +278,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_); TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(N == values.dim(0) && is_empty.dim(0) == N, "number of key and value must same"); TV_ASSERT_RT_ERR(N == values.dim(0) && is_empty.dim(0) == N, "number of key and value must same");
auto is_empty_ptr = is_empty.data_ptr<uint8_t>(); auto is_empty_ptr = is_empty.data_ptr<uint8_t>();
if (!is_cpu){{
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}}
""") """)
with code.if_("is_cpu"): with code.if_("is_cpu"):
map_name = "cpu_map" map_name = "cpu_map"
...@@ -304,10 +306,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -304,10 +306,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data()); K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
...@@ -319,7 +320,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -319,7 +320,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_ptr = reinterpret_cast<V*>(values.raw_data()); V* value_ptr = reinterpret_cast<V*>(values.raw_data());
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
tv::cuda::Launch launcher(N, custream); tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::query_split<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N)); launcher(tv::hash::query_split<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
...@@ -361,11 +362,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -361,11 +362,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(count.device() == 0, "count must be cuda"); TV_ASSERT_RT_ERR(count.device() == 0, "count must be cuda");
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
constexpr K kEmptyKey = std::numeric_limits<K>::max(); using Kunsigned = tv::hash::itemsize_to_unsigned_t<sizeof(K)>;
auto count_ptr = count.data_ptr<K>();
auto count_ptr = count.data_ptr<Kunsigned>();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
""") """)
...@@ -376,10 +378,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -376,10 +378,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data()); V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(table.size(), custream); tv::cuda::Launch launcher(table.size(), custream);
launcher(tv::hash::assign_arange_split<table_t, K>, table, count_ptr); launcher(tv::hash::assign_arange_split<table_t, Kunsigned>, table, count_ptr);
""") """)
else: else:
code.raw(f""" code.raw(f"""
...@@ -426,7 +428,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -426,7 +428,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_); TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_);
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_); TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(N == values.dim(0), "number of key and value must same"); TV_ASSERT_RT_ERR(N == values.dim(0), "number of key and value must same");
if (!is_cpu){{
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}}
""") """)
with code.if_("is_cpu"): with code.if_("is_cpu"):
map_name = "cpu_map" map_name = "cpu_map"
...@@ -450,12 +454,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -450,12 +454,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
auto count_ptr = count.data_ptr<K>(); using Kunsigned = tv::hash::itemsize_to_unsigned_t<sizeof(K)>;
auto count_ptr = count.data_ptr<Kunsigned>();
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data()); K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
...@@ -467,10 +471,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -467,10 +471,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_ptr = reinterpret_cast<V*>(values.raw_data()); V* value_ptr = reinterpret_cast<V*>(values.raw_data());
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
tv::cuda::Launch launcher(N, custream); tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::iterate_table_split<table_t, K>, table, key_ptr, value_ptr, size_t(N), count_ptr); launcher(tv::hash::iterate_table_split<table_t, Kunsigned>, table, key_ptr, value_ptr, size_t(N), count_ptr);
""") """)
else: else:
code.raw(f""" code.raw(f"""
...@@ -523,10 +527,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -523,10 +527,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data()); const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
...@@ -538,7 +541,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -538,7 +541,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
const V* value_ptr = reinterpret_cast<const V*>(values.raw_data()); const V* value_ptr = reinterpret_cast<const V*>(values.raw_data());
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(N, custream); tv::cuda::Launch launcher(N, custream);
launcher(insert_exist_keys_kernel<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N)); launcher(insert_exist_keys_kernel<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib, GemmBasicHost
from cumm.conv.bases import ConvOpType, NHWC from cumm.conv.bases import ConvOpType, NHWC
from cumm.conv.params import ConvProblem from cumm.conv.params import ConvProblem
from cumm import dtypes from cumm import dtypes
...@@ -23,7 +23,8 @@ from .pointops import Point2Voxel, Point2VoxelCPU ...@@ -23,7 +23,8 @@ from .pointops import Point2Voxel, Point2VoxelCPU
from .indices import SparseConvIndicesKernel, CudaCommonKernel, SparseConvIndicesCPU from .indices import SparseConvIndicesKernel, CudaCommonKernel, SparseConvIndicesCPU
from .maxpool import IndiceMaxPool, IndiceMaxPoolCPU from .maxpool import IndiceMaxPool, IndiceMaxPoolCPU
from .gather import GatherCPU from .gather import GatherCPU
from .alloc import ExternalAllocator, ThrustAllocator
from spconv.constants import SpconvAllocatorKeys
class CustomThrustLib(pccm.Class): class CustomThrustLib(pccm.Class):
def __init__(self): def __init__(self):
...@@ -31,7 +32,7 @@ class CustomThrustLib(pccm.Class): ...@@ -31,7 +32,7 @@ class CustomThrustLib(pccm.Class):
self.add_dependency(ThrustLib) self.add_dependency(ThrustLib)
# https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746 # https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
if compat.InLinux: if compat.InLinux:
self.build_meta.add_cflags("nvcc", "-Xcompiler", "-fno-gnu-unique") self.build_meta.add_public_cflags("nvcc", "-Xcompiler", "-fno-gnu-unique")
class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin): class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin):
...@@ -65,13 +66,13 @@ class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -65,13 +66,13 @@ class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("ptr", "char *") code.arg("ptr", "char *")
code.arg("num_bytes", "size_t") code.arg("num_bytes", "size_t")
return code return code
class SpconvOps(pccm.Class): class SpconvOps(pccm.Class):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.add_dependency(ThrustCustomAllocatorV2) self.add_dependency(ThrustCustomAllocatorV2, ExternalAllocator, GemmBasicHost, ThrustAllocator)
self.ndims = [1, 2, 3, 4] self.ndims = [1, 2, 3, 4]
for ndim in self.ndims: for ndim in self.ndims:
p2v = Point2Voxel(dtypes.float32, ndim) p2v = Point2Voxel(dtypes.float32, ndim)
...@@ -167,8 +168,8 @@ class SpconvOps(pccm.Class): ...@@ -167,8 +168,8 @@ class SpconvOps(pccm.Class):
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_conv_inds_stage2(self): def generate_conv_inds_stage2(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("indices, hashdata", "tv::Tensor") code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg("indice_pairs, indice_pairs_uniq, out_inds", "tv::Tensor") code.arg("indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds", "tv::Tensor")
code.arg("num_out_act", "int") code.arg("num_out_act", "int")
code.arg("batch_size", "int") code.arg("batch_size", "int")
code.arg("output_dims, input_dims", f"std::vector<int>") code.arg("output_dims, input_dims", f"std::vector<int>")
...@@ -198,8 +199,9 @@ class SpconvOps(pccm.Class): ...@@ -198,8 +199,9 @@ class SpconvOps(pccm.Class):
padding_[i] = padding[i]; padding_[i] = padding[i];
dilation_[i] = dilation[i]; dilation_[i] = dilation[i];
}} }}
return SpconvIndices{ndim}D::generate_conv_inds_stage2(indices, hashdata, return SpconvIndices{ndim}D::generate_conv_inds_stage2(indices,
indice_pairs, indice_pairs_uniq, out_inds, num_out_act, hashdata_k, hashdata_v, indice_pairs,
indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds, num_out_act,
batch_size, output_dims_, input_dims_, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int); ksize_, stride_, padding_, dilation_, transposed, stream_int);
}} }}
...@@ -260,9 +262,9 @@ class SpconvOps(pccm.Class): ...@@ -260,9 +262,9 @@ class SpconvOps(pccm.Class):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
return code.make_invalid() return code.make_invalid()
code.arg("indices, hashdata", "tv::Tensor") code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg( code.arg(
"indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds", "indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds",
"tv::Tensor") "tv::Tensor")
code.arg("mask_fwd, mask_bwd", "tv::Tensor") code.arg("mask_fwd, mask_bwd", "tv::Tensor")
code.arg("num_out_act", "int") code.arg("num_out_act", "int")
...@@ -291,8 +293,11 @@ class SpconvOps(pccm.Class): ...@@ -291,8 +293,11 @@ class SpconvOps(pccm.Class):
padding_[i] = padding[i]; padding_[i] = padding[i];
dilation_[i] = dilation[i]; dilation_[i] = dilation[i];
}} }}
return SpconvIndices{ndim}D::generate_conv_inds_stage2_mask(indices, hashdata, return SpconvIndices{ndim}D::generate_conv_inds_stage2_mask(
indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds, mask_fwd, mask_bwd, indices, hashdata_k, hashdata_v,
indice_pairs_fwd, indice_pairs_bwd,
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_, num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int); ksize_, stride_, padding_, dilation_, transposed, stream_int);
}} }}
...@@ -307,7 +312,7 @@ class SpconvOps(pccm.Class): ...@@ -307,7 +312,7 @@ class SpconvOps(pccm.Class):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
return code.make_invalid() return code.make_invalid()
code.arg("indices, hashdata", "tv::Tensor") code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor") code.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor")
code.arg("batch_size", "int") code.arg("batch_size", "int")
code.arg("input_dims", f"std::vector<int>") code.arg("input_dims", f"std::vector<int>")
...@@ -331,7 +336,8 @@ class SpconvOps(pccm.Class): ...@@ -331,7 +336,8 @@ class SpconvOps(pccm.Class):
ksize_[i] = ksize[i]; ksize_[i] = ksize[i];
dilation_[i] = dilation[i]; dilation_[i] = dilation[i];
}} }}
return SpconvIndices{ndim}D::generate_subm_conv_inds(indices, hashdata, return SpconvIndices{ndim}D::generate_subm_conv_inds(indices,
hashdata_k, hashdata_v,
indice_pairs, out_inds, indice_num_per_loc, indice_pairs, out_inds, indice_num_per_loc,
batch_size, input_dims_, batch_size, input_dims_,
ksize_, dilation_, indice_pair_mask, backward, ksize_, dilation_, indice_pair_mask, backward,
...@@ -566,7 +572,7 @@ class SpconvOps(pccm.Class): ...@@ -566,7 +572,7 @@ class SpconvOps(pccm.Class):
}} }}
""" """
code.add_dependency(ThrustLib, TensorViewKernel) code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", CudaCommonKernel()) code.add_param_class("cudakers", CudaCommonKernel())
code.raw(f""" code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream); cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
...@@ -588,14 +594,15 @@ class SpconvOps(pccm.Class): ...@@ -588,14 +594,15 @@ class SpconvOps(pccm.Class):
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark def sort_1d_by_key_allocator_template(self, use_allocator: bool):
@pccm.cuda.static_function
def sort_1d_by_key_allocator(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
return code.make_invalid() return code.make_invalid()
code.arg("data", "tv::Tensor") code.arg("data", "tv::Tensor")
code.arg("alloc_func", "std::function<std::uintptr_t(std::size_t)>") if not use_allocator:
code.arg("alloc_func", "std::function<std::uintptr_t(std::size_t)>")
else:
code.arg("allocator", "ThrustAllocator&")
code.arg("indices", code.arg("indices",
"tv::Tensor", "tv::Tensor",
...@@ -614,10 +621,13 @@ class SpconvOps(pccm.Class): ...@@ -614,10 +621,13 @@ class SpconvOps(pccm.Class):
}} }}
}} }}
""" """
code.add_dependency(ThrustLib, TensorViewKernel) code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", CudaCommonKernel()) code.add_param_class("cudakers", CudaCommonKernel())
if not use_allocator:
code.raw(f"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
""")
code.raw(f""" code.raw(f"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream); cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{ if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0); indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
...@@ -638,6 +648,19 @@ class SpconvOps(pccm.Class): ...@@ -638,6 +648,19 @@ class SpconvOps(pccm.Class):
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.cuda.static_function
def sort_1d_by_key_allocator(self):
# for python
return self.sort_1d_by_key_allocator_template(False)
@pccm.cuda.static_function
def sort_1d_by_key_allocator_v2(self):
# for cpp only
return self.sort_1d_by_key_allocator_template(True)
@pccm.pybind.mark @pccm.pybind.mark
@pccm.cuda.static_function @pccm.cuda.static_function
def sort_1d_by_key_split(self): def sort_1d_by_key_split(self):
...@@ -694,14 +717,15 @@ class SpconvOps(pccm.Class): ...@@ -694,14 +717,15 @@ class SpconvOps(pccm.Class):
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark def sort_1d_by_key_split_allocator_template(self, use_allocator: bool):
@pccm.cuda.static_function
def sort_1d_by_key_split_allocator(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
return code.make_invalid() return code.make_invalid()
code.arg("data", "tv::Tensor") code.arg("data", "tv::Tensor")
code.arg("alloc_func", "std::function<std::uintptr_t(std::size_t)>") if not use_allocator:
code.arg("alloc_func", "std::function<std::uintptr_t(std::size_t)>")
else:
code.arg("allocator", "ThrustAllocator&")
code.arg("mask", "tv::Tensor") code.arg("mask", "tv::Tensor")
...@@ -727,9 +751,11 @@ class SpconvOps(pccm.Class): ...@@ -727,9 +751,11 @@ class SpconvOps(pccm.Class):
""" """
code.add_dependency(CustomThrustLib, TensorViewKernel) code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", CudaCommonKernel()) code.add_param_class("cudakers", CudaCommonKernel())
if not use_allocator:
code.raw(f"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
""")
code.raw(f""" code.raw(f"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream); cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
// auto timer = tv::CudaContextTimer<>(); // auto timer = tv::CudaContextTimer<>();
if (indices.empty()){{ if (indices.empty()){{
...@@ -755,6 +781,18 @@ class SpconvOps(pccm.Class): ...@@ -755,6 +781,18 @@ class SpconvOps(pccm.Class):
""") """)
return code.ret("tv::Tensor") return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.cuda.static_function
def sort_1d_by_key_split_allocator(self):
return self.sort_1d_by_key_split_allocator_template(False)
@pccm.pybind.mark
@pccm.cuda.static_function
def sort_1d_by_key_split_allocator_v2(self):
return self.sort_1d_by_key_split_allocator_template(True)
@pccm.pybind.mark @pccm.pybind.mark
@pccm.cuda.static_function @pccm.cuda.static_function
def count_bits(self): def count_bits(self):
...@@ -947,3 +985,411 @@ class SpconvOps(pccm.Class): ...@@ -947,3 +985,411 @@ class SpconvOps(pccm.Class):
""") """)
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>") return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
@pccm.pybind.mark
@pccm.static_function
def get_int32_max(self):
code = pccm.FunctionCode()
code.raw(f"return std::numeric_limits<int>::max();")
return code.ret("int")
@pccm.static_function
def get_conv_output_size(self):
code = pccm.FunctionCode()
code.arg("input_dims", f"std::vector<int>")
code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.raw(f"""
int ndim = input_dims.size();
std::vector<int> out_dims;
for (int i = 0; i < ndim; ++i){{
if (ksize[i] == -1){{
out_dims.push_back(1);
}}else{{
auto size = (input_dims[i] + 2 * padding[i] - dilation[i] *
(ksize[i] - 1) - 1) / stride[i] + 1;
out_dims.push_back(size);
}}
}}
return out_dims;
""")
return code.ret("std::vector<int>")
@pccm.static_function
def get_deconv_output_size(self):
code = pccm.FunctionCode()
code.arg("input_dims", f"std::vector<int>")
code.arg("ksize, stride, padding, dilation, output_padding", f"std::vector<int>")
code.raw(f"""
int ndim = input_dims.size();
std::vector<int> out_dims;
for (int i = 0; i < ndim; ++i){{
if (ksize[i] == -1){{
TV_THROW_INVALID_ARG("kernel size can't be -1");
}}else{{
auto size = (input_dims[i] - 1) * stride[i] - 2 * padding[i] + ksize[
i] + output_padding[i];
out_dims.push_back(size);
}}
}}
return out_dims;
""")
return code.ret("std::vector<int>")
@pccm.cuda.static_function
def apply_thrust_unique_to_indice_pairs_uniq(self):
code = pccm.code()
code.add_dependency(CustomThrustLib)
code.arg("data", "tv::Tensor")
code.arg("allocator", "ThrustAllocator&")
code.arg("stream_int", f"std::uintptr_t", "0")
code.raw(f"""
int num_out_act = 0;
int uniq_size = data.dim(0);
tv::dispatch<int32_t, int64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>());
auto thrust_ctx = thrust::cuda::par(allocator).on(reinterpret_cast<cudaStream_t>(stream_int));
thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
auto new_end = thrust::unique(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
num_out_act = new_end - ptr_tr - 1;
}});
return num_out_act;
""")
return code.ret("int")
@pccm.pybind.mark
@pccm.static_function
def get_indice_pairs_implicit_gemm(self):
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.arg("indices", "tv::Tensor")
code.arg("batch_size", "int")
code.arg("input_dims", f"std::vector<int>")
code.arg("algo", "int")
code.arg("ksize, stride, padding, dilation, out_padding", f"std::vector<int>")
code.arg("subm, transposed, is_train", f"bool")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
throw std::runtime_error("this function can only be used with CUDA.")
""")
return code.ret("tv::Tensor")
code.raw(f"""
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
std::vector<int64_t> input_dims_i64(input_dims.begin(), input_dims.end());
int64_t spatial_volume = std::accumulate(input_dims_i64.begin(),
input_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k = spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
std::vector<int> out_shape;
if (!subm){{
if (transposed){{
out_shape = get_deconv_output_size(input_dims, ksize, stride, padding, dilation, out_padding);
}}else{{
out_shape = get_conv_output_size(input_dims, ksize, stride, padding, dilation);
}}
}}else{{
out_shape = input_dims;
}}
for (auto& v : out_shape){{
if (v <= 0){{
TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims);
}}
}}
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm ||
conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm, "only support implicit gemm");
bool is_mask_split = conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm;
int mask_split_count = is_mask_split ? 2 : 1;
tv::Tensor pair;
if (subm){{
pair = allocator.full_int({pccm.literal(SpconvAllocatorKeys.Pair)},
{{2, kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
}}else{{
pair = allocator.full_int({pccm.literal(SpconvAllocatorKeys.Pair)},
{{kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
}}
auto indice_num_per_loc = allocator.zeros({pccm.literal(SpconvAllocatorKeys.IndiceNumPerLoc)},
{{kv}}, indices.dtype(), indices.device());
tv::Tensor mask_tensor = tv::zeros({{mask_split_count}}, tv::uint32, -1);
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
if (is_mask_split){{
auto kv_div_2 = kv / 2;
auto remain = kv - kv_div_2;
uint64_t mask_np_1 = 1;
uint64_t first = ((mask_np_1 << remain) - 1);
uint64_t second = ((mask_np_1 << kv_div_2) - 1) << remain;
mask_tensor_ptr[0] = uint32_t(first);
mask_tensor_ptr[1] = uint32_t(second);
}}
else{{
mask_tensor_ptr[1] = 0xffffffff;
}}
tv::Tensor out_inds;
ThrustAllocator thrustalloc(allocator);
if (subm){{
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
out_inds = indices;
int num_points = out_inds.dim(0);
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_points * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_points * 2}}, tv::int32, 0);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}}, tv::int32, 0);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
auto pair_mask = allocator.empty({pccm.literal(SpconvAllocatorKeys.PairMask)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, false, stream_int);
auto mask_argsort = allocator.empty({pccm.literal(SpconvAllocatorKeys.MaskArgSort)},
{{mask_split_count, out_inds.dim(0)}}, tv::uint32, 0);
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int);
}}
}}else{{
auto pair_bwd = pair;
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
generate_conv_inds_mask_stage1(indices, pair_bwd, indice_pairs_uniq,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed, stream_int);
indice_pairs_uniq_bkp_guard->tensor.copy_(indice_pairs_uniq, tvctx);
// TODO pytorch unique may be faster?
int num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int) - 1;
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
out_inds = allocator.empty({pccm.literal(SpconvAllocatorKeys.OutIndices)},
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0);
auto pair_fwd = allocator.full_int({pccm.literal(SpconvAllocatorKeys.PairFwd)},
{{kv, num_act_out}}, -1, indices.dtype(), indices.device());
auto pair_mask_fwd = allocator.zeros({pccm.literal(SpconvAllocatorKeys.PairMask)},
{{mask_split_count, num_act_out}}, tv::uint32, 0);
auto pair_mask_bwd = tv::Tensor();
if (is_train){{
pair_mask_bwd = allocator.zeros({pccm.literal(SpconvAllocatorKeys.PairMaskBwd)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
}}
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_act_out * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}}, tv::int32, 0);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}}, tv::int32, 0);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp_guard->tensor,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
auto mask_argsort_fwd = allocator.empty({pccm.literal(SpconvAllocatorKeys.MaskArgSort)},
{{mask_split_count, out_inds.dim(0)}}, tv::uint32, 0);
tv::Tensor mask_argsort_bwd = tv::Tensor();
if (is_train){{
mask_argsort_bwd = allocator.zeros({pccm.literal(SpconvAllocatorKeys.MaskArgSortBwd)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
}}
if (is_mask_split){{
for (int j = 0; j < mask_split_count; ++j){{
if (!is_train){{
sort_1d_by_key_split_allocator_v2(pair_mask_fwd[j], thrustalloc,
mask_tensor[j], mask_argsort_fwd[j], stream_int);
}}else{{
sort_1d_by_key_split_allocator_v2(pair_mask_fwd[j], thrustalloc,
mask_tensor[j], mask_argsort_fwd[j], stream_int);
sort_1d_by_key_split_allocator_v2(pair_mask_bwd[j], thrustalloc,
mask_tensor[j], mask_argsort_bwd[j], stream_int);
}}
}}
}}else{{
if (!is_train){{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int);
}}else{{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int);
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int);
}}
}}
}}
return mask_tensor;
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.static_function
def get_indice_pairs(self):
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.arg("indices", "tv::Tensor")
code.arg("batch_size", "int")
code.arg("input_dims", f"std::vector<int>")
code.arg("algo", "int")
code.arg("ksize, stride, padding, dilation, out_padding", f"std::vector<int>")
code.arg("subm, transposed", f"bool")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
throw std::runtime_error("this function can only be used with CUDA.")
""")
return code
code.raw(f"""
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kNative, "only support kNative");
std::vector<int64_t> input_dims_i64(input_dims.begin(), input_dims.end());
int64_t spatial_volume = std::accumulate(input_dims_i64.begin(),
input_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k = spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
std::vector<int> out_shape;
if (!subm){{
if (transposed){{
out_shape = get_deconv_output_size(input_dims, ksize, stride, padding, dilation, out_padding);
}}else{{
out_shape = get_conv_output_size(input_dims, ksize, stride, padding, dilation);
}}
}}else{{
out_shape = input_dims;
}}
for (auto& v : out_shape){{
if (v <= 0){{
TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims);
}}
}}
tv::Tensor pair;
pair = allocator.full_int({pccm.literal(SpconvAllocatorKeys.Pair)},
{{2, kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
auto indice_num_per_loc = allocator.zeros({pccm.literal(SpconvAllocatorKeys.IndiceNumPerLoc)},
{{kv}}, indices.dtype(), indices.device());
tv::Tensor out_inds;
""")
with code.if_("subm"):
code.raw(f"""
if (indices.is_cpu()){{
generate_subm_conv_inds_cpu(indices, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation);
}}
""")
if not CUMM_CPU_ONLY_BUILD:
code.raw(f"""
else {{
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
out_inds = indices;
int num_points = out_inds.dim(0);
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_points * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_points * 2}}, tv::int32, 0);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}}, tv::int32, 0);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, tv::Tensor(), false, stream_int);
}}
""")
else:
code.raw(f"""
else {{
TV_THROW_RT_ERR("not implemented for CPU ONLY build.")
}}
""")
with code.else_():
code.raw(f"""
if (indices.is_cpu()){{
out_inds = allocator.empty({pccm.literal(SpconvAllocatorKeys.OutIndices)},
{{kv * indices.dim(0), indices.dim(1)}}, indices.dtype(), -1);
generate_conv_inds_cpu(indices, pair, out_inds, indice_num_per_loc,
batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed);
}}
""")
if not CUMM_CPU_ONLY_BUILD:
code.raw(f"""
else {{
ThrustAllocator thrustalloc(allocator);
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
generate_conv_inds_stage1(indices, pair, indice_pairs_uniq,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed, stream_int);
indice_pairs_uniq_bkp_guard->tensor.copy_(indice_pairs_uniq, tvctx);
// TODO pytorch unique may be faster?
int num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int) - 1;
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
out_inds = allocator.empty({pccm.literal(SpconvAllocatorKeys.OutIndices)},
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0);
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_act_out * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}}, tv::int32, 0);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}}, tv::int32, 0);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
generate_conv_inds_stage2(indices, hash_k, hash_v, pair,
indice_pairs_uniq, indice_pairs_uniq_bkp_guard->tensor,
out_inds, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
}}
""")
else:
code.raw(f"""
else {{
TV_THROW_RT_ERR("not implemented for CPU ONLY build.")
}}
""")
code.raw(f"""
return;
""")
return code
@pccm.pybind.mark
@pccm.static_function
def test_allocator(self):
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.raw(f"""
auto guard = allocator.zeros_guard({{1, 2, 3}}, tv::int32, 0);
tv::ssprint("????");
""")
return code
\ No newline at end of file
import pccm
from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib
class ExternalAllocatorGuard(pccm.Class):
def __init__(self):
super().__init__()
self.add_dependency(TensorView)
self.add_member("tensor", "tv::Tensor")
self.add_member("free_func", "std::function<void(tv::Tensor)>")
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("ten", "tv::Tensor")
code.arg("free_func", "std::function<void(tv::Tensor)>")
code.ctor_init("tensor", "ten")
code.ctor_init("free_func", "free_func")
return code
@pccm.constructor
def dctor(self):
code = pccm.code()
return code
@pccm.destructor
def dtor(self):
code = pccm.code()
code.raw(f"""
if (!tensor.empty() && free_func){{
free_func(tensor);
}}
""")
return code
class ExternalAllocator(pccm.Class):
def __init__(self):
super().__init__()
self.add_dependency(TensorView, ExternalAllocatorGuard)
self.use_shared = True
self.ptr_type = "unique"
if self.use_shared:
self.ptr_type = "shared"
self.add_typedef("guard_t", f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def zeros(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def empty(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def full_int(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("value", "int")
code.arg("dtype", "int")
code.arg("device", "int")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def full_float(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("value", "float")
code.arg("dtype", "int")
code.arg("device", "int")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def free(self):
code = pccm.code()
code.arg("ten", "tv::Tensor")
return code
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def free_noexcept(self):
code = pccm.code()
code.arg("ten", "tv::Tensor")
return code
@pccm.member_function
def zeros_guard(self):
code = pccm.code()
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.raw(f"""
// "" means temp memory
auto ten = zeros("", shape, dtype, device);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
""")
return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
@pccm.member_function
def empty_guard(self):
code = pccm.code()
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.raw(f"""
auto ten = empty("", shape, dtype, device);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
""")
return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
@pccm.member_function
def full_int_guard(self):
code = pccm.code()
code.arg("shape", "std::vector<int64_t>")
code.arg("value", "int")
code.arg("dtype", "int")
code.arg("device", "int")
code.raw(f"""
auto ten = full_int("", shape, value, dtype, device);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
""")
return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
@pccm.member_function
def full_float_guard(self):
code = pccm.code()
code.arg("shape", "std::vector<int64_t>")
code.arg("value", "int")
code.arg("dtype", "int")
code.arg("device", "int")
code.raw(f"""
auto ten = full_float("", shape, value, dtype, device);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor t){{
this->free(t);
}});
""")
return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
class ThrustAllocator(pccm.Class):
def __init__(self):
super().__init__()
self.add_dependency(TensorView, ExternalAllocator)
self.add_include("functional", "memory")
self.add_member("allocator_", "ExternalAllocator&",)
self.add_typedef("value_type", "char")
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.ctor_init("allocator_", "allocator")
return code
@pccm.member_function
def allocate(self):
code = pccm.FunctionCode()
code.arg("num_bytes", "std::ptrdiff_t")
code.ret("char*")
code.raw(f"""
auto ten = allocator_.empty("", {{num_bytes}}, tv::uint8, 0);
return reinterpret_cast<char*>(ten.raw_data());
""")
return code
@pccm.member_function
def deallocate(self):
code = pccm.FunctionCode()
code.arg("ptr", "char *")
code.arg("num_bytes", "size_t")
code.raw(f"""
return allocator_.free_noexcept(tv::from_blob(ptr, {{num_bytes}}, tv::uint8, 0));
""")
return code
import pccm
from cumm.gemm.main import GemmMainUnitTest
from cumm.conv.main import ConvMainUnitTest
from .alloc import ExternalAllocator
from spconv.core import ConvAlgo
from spconv.constants import SpconvAllocatorKeys
from cumm.constants import CUMM_CPU_ONLY_BUILD
from cumm.common import GemmBasicHost, TensorView, NlohmannJson
class GemmTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
def __init__(self):
super().__init__()
self.add_dependency(GemmBasicHost, TensorView)
self.add_pybind_member("algo_desp", "tv::gemm::GemmAlgoDesp")
self.add_pybind_member("arch", "std::tuple<int, int>")
self.add_pybind_member("splitk", "int")
@pccm.pybind.mark
@pccm.member_function
def is_valid(self):
code = pccm.code()
code.raw(f"return splitk > 0 && std::get<0>(arch) > 0")
return code
@pccm.pybind.mark
@pccm.constructor
def defaultctor(self):
code = pccm.code()
code.ctor_init("algo_desp", "tv::gemm::GemmAlgoDesp()")
code.ctor_init("arch", "std::make_tuple(-1, -1)")
code.ctor_init("splitk", "-1")
return code
@pccm.pybind.mark
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("algo_desp",
"tv::gemm::GemmAlgoDesp",
pyanno="cumm.tensorview.gemm.GemmAlgoDesp")
code.arg("arch", "std::tuple<int, int>")
code.arg("splitk", "int")
code.ctor_init("algo_desp", "algo_desp")
code.ctor_init("arch", "arch")
code.ctor_init("splitk", "splitk")
return code
class ConvTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
def __init__(self):
super().__init__()
self.add_dependency(GemmBasicHost, TensorView)
self.add_pybind_member("algo_desp", "tv::gemm::ConvAlgoDesp")
self.add_pybind_member("arch", "std::tuple<int, int>")
self.add_pybind_member("splitk", "int")
@pccm.pybind.mark
@pccm.constructor
def defaultctor(self):
code = pccm.code()
code.ctor_init("algo_desp", "tv::gemm::ConvAlgoDesp()")
code.ctor_init("arch", "std::make_tuple(-1, -1)")
code.ctor_init("splitk", "-1")
return code
@pccm.pybind.mark
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("algo_desp",
"tv::gemm::ConvAlgoDesp",
pyanno="cumm.tensorview.gemm.ConvAlgoDesp")
code.arg("arch", "std::tuple<int, int>")
code.arg("splitk", "int")
code.ctor_init("algo_desp", "algo_desp")
code.ctor_init("arch", "arch")
code.ctor_init("splitk", "splitk")
return code
@pccm.pybind.mark
@pccm.member_function
def is_valid(self):
code = pccm.code()
code.raw(f"return splitk > 0 && std::get<0>(arch) > 0")
return code
class GemmTunerSimple(pccm.ParameterizedClass):
def __init__(self, gemm_cu: GemmMainUnitTest, conv_cu: ConvMainUnitTest):
super().__init__()
self.add_dependency(ExternalAllocator, GemmTuneResult,
ConvTuneResult, TensorView)
self.add_param_class("gemm", gemm_cu, "GemmMain")
self.add_param_class("conv", conv_cu, "ConvMain")
self.add_include("tensorview/utility/tuplehash.h")
self.add_member("desps_", "std::vector<tv::gemm::GemmAlgoDesp>")
self.add_member("nvrtc_progs_", "std::unordered_map<std::string, tv::NVRTCProgram>")
self.add_member("nvrtc_caches_", "std::unordered_map<std::tuple<std::string, int, int, std::uintptr_t>, tv::NVRTCModule>")
@pccm.pybind.mark
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("desps", "std::vector<tv::gemm::GemmAlgoDesp>")
code.arg("nvrtc_progs", "std::unordered_map<std::string, std::string>")
code.ctor_init("desps_", "desps")
code.raw(f"""
for (auto& v : nvrtc_progs){{
const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());
nvrtc_progs_.insert(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));
}}
""")
return code
@pccm.member_function
def get_all_available(self):
code = pccm.code()
code.arg("a, b, c", "tv::Tensor")
code.arg("trans_a, trans_b, trans_c", "bool")
code.arg("arch", "std::tuple<int, int>")
code.arg("nvrtc_progs", "std::unordered_map<std::string, std::string>")
code.ctor_init("desps_", "desps")
code.raw(f"""
for (auto& v : nvrtc_progs){{
const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());
nvrtc_progs_.insert(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));
}}
""")
return code
class ConvGemmOps(pccm.ParameterizedClass):
def __init__(self, gemm_cu: GemmMainUnitTest, conv_cu: ConvMainUnitTest):
super().__init__()
self.add_dependency(ExternalAllocator, GemmTuneResult,
ConvTuneResult)
self.add_param_class("gemm", gemm_cu, "GemmMain")
self.add_param_class("conv", conv_cu, "ConvMain")
@pccm.pybind.mark
@pccm.static_function
def indice_conv(self):
"""1. this function need to take a out features
that from subm first mm.
2. this function don't support CPU.
"""
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.arg("out_features_after_mm", "tv::Tensor")
code.arg("features, filters, indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor")
code.arg("num_activate_out", "int")
code.arg("inverse", "bool", "false")
code.arg("subm", "bool", "false")
code.arg("algo", "int", f"{ConvAlgo.Native.value}")
code.arg("filter_hwio", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
throw std::runtime_error("this function can only be used with CUDA.")
""")
return code.ret("tv::Tensor")
code.raw(f"""
TV_ASSERT_RT_ERR(!features.is_cpu(), "this function don't support cpu.")
int out_channel;
if (filter_hwio){{
out_channel = filters.dim(-1);
}}else{{
out_channel = filters.dim(-2);
}}
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
int kv = filters.dim(0);
int kv_center = kv / 2;
tv::Tensor out_features;
if (kv == 1 && subm){{
return;
}}
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
int maxnhot = 0;
bool all_zero = true;
for (int i = 0; i < kv; ++i){{
if (indice_pair_num_cpu_ptr[i] != 0){{
all_zero = false;
maxnhot = std::max(maxnhot, indice_pair_num_cpu_ptr[i]);
}}
}}
if (subm && all_zero){{
return;
}}
bool inited = subm;
auto a = features;
auto c = out_features;
auto pair_in = indice_pairs[int(inverse)];
auto pair_out = indice_pairs[int(!inverse)];
""")
return code
...@@ -23,13 +23,8 @@ class OMPLib(pccm.Class): ...@@ -23,13 +23,8 @@ class OMPLib(pccm.Class):
self.add_dependency(TensorView) self.add_dependency(TensorView)
self.add_include("tensorview/parallel/all.h") self.add_include("tensorview/parallel/all.h")
if compat.InWindows: if compat.InWindows:
self.build_meta.add_cflags("cl", "/openmp") self.build_meta.add_public_cflags("cl", "/openmp")
else: else:
self.build_meta.add_cflags("g++", "-fopenmp") self.build_meta.add_public_cflags("g++", "-fopenmp")
self.build_meta.add_cflags("clang++", "-fopenmp") self.build_meta.add_public_cflags("clang++", "-fopenmp")
if "g++" not in self.build_meta.compiler_to_ldflags: self.build_meta.add_ldflags("g++,clang++", "-fopenmp")
self.build_meta.compiler_to_ldflags["g++"] = []
self.build_meta.compiler_to_ldflags["g++"].extend(["-fopenmp"])
if "clang++" not in self.build_meta.compiler_to_ldflags:
self.build_meta.compiler_to_ldflags["clang++"] = []
self.build_meta.compiler_to_ldflags["clang++"].extend(["-fopenmp"])
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
from cumm.conv.bases import ConvEnum
from cumm.gemm.core.metaarray import MetaArray, seq from cumm.gemm.core.metaarray import MetaArray, seq
from cumm import dtypes from cumm import dtypes
import pccm import pccm
...@@ -255,7 +254,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -255,7 +254,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
self.add_param_class("spinds", self.loc_iter, "ConvLocIter") self.add_param_class("spinds", self.loc_iter, "ConvLocIter")
self.add_param_class("spinds", problem, "ConvProblem") self.add_param_class("spinds", problem, "ConvProblem")
self.add_param_class("cudakers", CudaCommonKernel()) self.add_param_class("cudakers", CudaCommonKernel())
self.add_include("tensorview/hash/ops.h")
self.ndim = problem.ndim self.ndim = problem.ndim
self.dtype_indices = dtype_indices self.dtype_indices = dtype_indices
self.dtype_indices_uniq = dtype_indices self.dtype_indices_uniq = dtype_indices
...@@ -265,13 +264,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -265,13 +264,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage1(self): def calc_conv_indices_stage1(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TIndiceUniq")
code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1] code.arg("indices_in", f"const int*") # [N, ndim + 1]
code.arg("indice_pairs", code.arg("indice_pairs",
f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] f"{self.dtype_indices}*") # [2, kernelProd, MaxSize]
code.arg("indice_pairs_for_uniq", code.arg("indice_pairs_for_uniq",
f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] f"TIndiceUniq*") # [2, kernelProd, MaxSize]
code.arg("indice_num_per_loc", f"int*") # [kernelProd] code.arg("indice_num_per_loc", f"int*") # [kernelProd]
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
...@@ -295,10 +295,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -295,10 +295,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
}} }}
if (valid){{ if (valid){{
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset); int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
{self.dtype_indices} offset = loc_iter.layout_npq(npq_offset); int64_t offset = loc_iter.layout_npq(npq_offset);
if (old_num < indices_pair_size){{ if (old_num < indices_pair_size){{
indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i; indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] = offset; // indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] = offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = offset; indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = offset;
}} }}
}} }}
...@@ -314,7 +314,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -314,7 +314,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indices_out", f"int*") # [N, ndim + 1] code.arg("indices_out", f"int*") # [N, ndim + 1]
code.arg("indice_pairs_for_uniq", code.arg("indice_pairs_for_uniq",
f"const {self.dtype_indices}*") # [2, kernelProd, MaxSize] f"const typename TTable::key_type*") # [2, kernelProd, MaxSize]
code.arg("layout_npq", code.arg("layout_npq",
f"spinds::LayoutNPQ") # [2, kernelProd, MaxSize] f"spinds::LayoutNPQ") # [2, kernelProd, MaxSize]
...@@ -323,7 +323,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -323,7 +323,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.raw(f""" code.raw(f"""
for (int output_index : tv::KernelLoopX<int>(num_indices)) {{ for (int output_index : tv::KernelLoopX<int>(num_indices)) {{
{self.dtype_indices} output_coord_offset = indice_pairs_for_uniq[output_index]; auto output_coord_offset = indice_pairs_for_uniq[output_index];
layout_npq.inverse(output_coord_offset, indices_out + {self.ndim + 1} * output_index); layout_npq.inverse(output_coord_offset, indices_out + {self.ndim + 1} * output_index);
table.insert(output_coord_offset, output_index); table.insert(output_coord_offset, output_index);
}} }}
...@@ -334,20 +334,24 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -334,20 +334,24 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def calc_conv_indices_stage2(self): def calc_conv_indices_stage2(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TTable") code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indice_pairs_out_part", f"int*") # [2, kernelProd, MaxSize] code.arg("indice_pairs_uniq_before_sort", f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("indice_pairs_out_part", f"int*") # [kernelProd, MaxSize]
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
code.arg("indices_pair_size", "int") code.arg("indices_pair_size", "int")
# TODO use block instead of filter_offset? # TODO use block instead of filter_offset?
code.raw(f""" code.raw(f"""
int filter_offset = blockIdx.y; int filter_offset = blockIdx.y;
auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size; auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * indices_pair_size;
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{ for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
{self.dtype_indices} output_coord_offset = indice_pairs_out_part_filter[i]; {self.dtype_indices} output_coord_offset = indice_pairs_uniq_before_sort_filter[i];
if (output_coord_offset > -1){{ if (output_coord_offset != std::numeric_limits<typename TTable::key_type>::max()){{
auto ptr = table.lookup_ptr(output_coord_offset); auto table_offset = table.lookup_offset(output_coord_offset);
if (ptr){{ if (table_offset != -1){{
indice_pairs_out_part_filter[i] = ptr->second; indice_pairs_out_part_filter[i] = table.value_ptr()[table_offset];
}} }}
}} }}
}} }}
...@@ -357,13 +361,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -357,13 +361,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_conv_indices_stage1_mask(self): def calc_conv_indices_stage1_mask(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TIndiceUniq")
code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1] code.arg("indices_in", f"const int*") # [N, ndim + 1]
code.arg("indice_pairs_bwd", code.arg("indice_pairs_bwd",
f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] f"{self.dtype_indices}*") # [kernelProd, MaxSize]
code.arg("indice_pairs_for_uniq", code.arg("indice_pairs_for_uniq",
f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] f"TIndiceUniq*") # [2, kernelProd, MaxSize]
code.arg("indice_num_per_loc", f"int*") # [kernelProd] code.arg("indice_num_per_loc", f"int*") # [kernelProd]
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
...@@ -386,12 +392,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -386,12 +392,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
}} }}
if (valid){{ if (valid){{
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset); int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
{self.dtype_indices} output_coord_offset = loc_iter.layout_npq(npq_offset); TIndiceUniq output_coord_offset = loc_iter.layout_npq(npq_offset);
// if (old_num < indices_pair_size){{ // if (old_num < indices_pair_size){{
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i; // indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
// indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset; // indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = output_coord_offset; indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
// }} // }}
}} }}
}} }}
...@@ -407,6 +413,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -407,6 +413,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f"int*") # [kernelProd, MaxSize], inp -> out f"int*") # [kernelProd, MaxSize], inp -> out
code.arg("indice_pairs_bwd", code.arg("indice_pairs_bwd",
f"int*") # [kernelProd, MaxSize], out -> inp f"int*") # [kernelProd, MaxSize], out -> inp
code.arg("indice_pairs_uniq_before_sort", f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("mask_fwd", f"uint32_t*") # [kernelProd] code.arg("mask_fwd", f"uint32_t*") # [kernelProd]
code.arg("mask_bwd", f"uint32_t*") # [kernelProd] code.arg("mask_bwd", f"uint32_t*") # [kernelProd]
...@@ -422,12 +429,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -422,12 +429,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out; auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in; auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{ for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
{self.dtype_indices} output_coord_offset = indice_pairs_bwd_filter[input_index]; auto output_coord_offset = indice_pairs_uniq_before_sort_filter[input_index];
if (output_coord_offset > -1){{ if (output_coord_offset != std::numeric_limits<typename TTable::key_type>::max()){{
auto ptr = table.lookup_ptr(output_coord_offset);
if (ptr){{ auto table_offset = table.lookup_offset(output_coord_offset);
auto output_index = ptr->second; if (table_offset != -1){{
auto output_index = table.value_ptr()[table_offset];
atomicOr(mask_fwd + output_index, filter_mask_fwd); atomicOr(mask_fwd + output_index, filter_mask_fwd);
// atomicOr(mask_bwd + input_index, filter_mask_bwd); // atomicOr(mask_bwd + input_index, filter_mask_bwd);
indice_pairs_fwd_filter[output_index] = input_index; indice_pairs_fwd_filter[output_index] = input_index;
...@@ -465,11 +475,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -465,11 +475,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def calc_conv_indices_stage2_inference_mask(self): def calc_conv_indices_stage2_inference_mask(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TTable") code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indice_pairs_fwd", code.arg("indice_pairs_fwd",
f"int*") # [kernelProd, MaxSize], inp -> out f"int*") # [kernelProd, MaxSize], inp -> out
code.arg("indice_pairs_bwd", code.arg("indice_pairs_bwd",
f"int*") # [kernelProd, MaxSize], out -> inp f"int*") # [kernelProd, MaxSize], out -> inp
code.arg("indice_pairs_uniq_before_sort", f"const typename TTable::key_type*") # [kernelProd, MaxSize]
code.arg("mask_fwd", f"uint32_t*") # [kernelProd] code.arg("mask_fwd", f"uint32_t*") # [kernelProd]
code.arg("num_indices_in", "int") code.arg("num_indices_in", "int")
code.arg("num_indices_out", "int") code.arg("num_indices_out", "int")
...@@ -481,12 +494,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -481,12 +494,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out; auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in; auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{ for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
{self.dtype_indices} output_coord_offset = indice_pairs_bwd_filter[input_index]; auto output_coord_offset = indice_pairs_uniq_before_sort_filter[input_index];
if (output_coord_offset > -1){{ if (output_coord_offset != std::numeric_limits<typename TTable::key_type>::max()){{
auto ptr = table.lookup_ptr(output_coord_offset); auto table_offset = table.lookup_offset(output_coord_offset);
if (ptr){{ if (table_offset != -1){{
auto output_index = ptr->second; auto output_index = table.value_ptr()[table_offset];
atomicOr(mask_fwd + output_index, filter_mask_fwd); atomicOr(mask_fwd + output_index, filter_mask_fwd);
indice_pairs_fwd_filter[output_index] = input_index; indice_pairs_fwd_filter[output_index] = input_index;
}} }}
...@@ -499,7 +513,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -499,7 +513,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def build_subm_conv_hash_table(self): def build_subm_conv_hash_table(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.targ("TTable") code.targ("TTable")
code.arg("table", f"TTable") # [N, ndim + 1] code.arg("table", f"TTable") # [N, ndim + 1]
code.arg("indices_in", f"const int*") # [N, ndim + 1] code.arg("indices_in", f"const int*") # [N, ndim + 1]
...@@ -509,8 +522,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -509,8 +522,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.raw(f""" code.raw(f"""
for (int i : tv::KernelLoopX<int>(num_indices)) {{ for (int i : tv::KernelLoopX<int>(num_indices)) {{
{self.dtype_indices} index = layout_npq(indices_in + i * {self.ndim + 1}); table.insert(layout_npq(indices_in + i * {self.ndim + 1}), i);
table.insert(index, i);
}} }}
""") """)
return code return code
...@@ -518,11 +530,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -518,11 +530,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def clean_indices_uniq(self): def clean_indices_uniq(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("indice_pairs_for_uniq", f"{self.dtype_indices}*") code.targ("T")
code.arg("size", f"{self.dtype_indices}") code.arg("indice_pairs_for_uniq", f"T*")
code.arg("size", f"size_t")
code.raw(f""" code.raw(f"""
for ({self.dtype_indices} i : tv::KernelLoopX<{self.dtype_indices}>(size)) {{ for (size_t i : tv::KernelLoopX<size_t>(size)) {{
indice_pairs_for_uniq[i] = std::numeric_limits<{self.dtype_indices}>::max(); indice_pairs_for_uniq[i] = std::numeric_limits<T>::max();
}} }}
""") """)
return code return code
...@@ -559,13 +572,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -559,13 +572,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{ for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
tv::array<int, {self.ndim + 1}> npq_offset; tv::array<int, {self.ndim + 1}> npq_offset;
if (loc_iter.query_npq_no_stride(indices_in + i * {self.ndim + 1}, npq_offset)){{ if (loc_iter.query_npq_no_stride(indices_in + i * {self.ndim + 1}, npq_offset)){{
{self.dtype_indices} offset = loc_iter.layout_npq(npq_offset); auto offset = loc_iter.layout_npq(npq_offset);
auto item = table.lookup(offset); // performance bound // auto item = table.lookup(offset); // performance bound
if (!item.empty()){{ auto table_offset = table.lookup_offset(offset); // performance bound
if (table_offset != -1){{
auto v = table.value_ptr()[table_offset];
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset); int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i; indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] = item.second; indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] = v;
indice_pairs[filter_offset_mul_indices_pair_size_1 + old_num] = item.second; indice_pairs[filter_offset_mul_indices_pair_size_1 + old_num] = v;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + old_num] = i; indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + old_num] = i;
}} }}
}} }}
...@@ -613,10 +628,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -613,10 +628,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::array<int, {self.ndim + 1}> nhw_offset; tv::array<int, {self.ndim + 1}> nhw_offset;
// table: input indice coord to output index (or output indice coord to input index) // table: input indice coord to output index (or output indice coord to input index)
if (loc_iter.query_nhw(indices_in + output_index * {self.ndim + 1}, nhw_offset)){{ if (loc_iter.query_nhw(indices_in + output_index * {self.ndim + 1}, nhw_offset)){{
{self.dtype_indices} offset = loc_iter.layout_npq(nhw_offset); auto offset = loc_iter.layout_npq(nhw_offset);
auto item = table.lookup(offset); // auto item = table.lookup(offset);
if (!item.empty()) {{ auto table_offset = table.lookup_offset(offset); // performance bound
auto input_index = item.second; // we find a input indice idx. if (table_offset != -1){{
auto input_index = table.value_ptr()[table_offset]; // we find a input indice idx.
atomicOr(mask + output_index, filter_mask_out); atomicOr(mask + output_index, filter_mask_out);
atomicOr(mask + input_index, filter_mask_in); atomicOr(mask + input_index, filter_mask_in);
// for this output, we set correct input idx. // for this output, we set correct input idx.
...@@ -670,10 +686,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -670,10 +686,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::array<int, {self.ndim + 1}> nhw_offset; tv::array<int, {self.ndim + 1}> nhw_offset;
// table: input indice coord to output index (or output indice coord to input index) // table: input indice coord to output index (or output indice coord to input index)
if (loc_iter.query_nhw(indices_in + output_index * {self.ndim + 1}, nhw_offset)){{ if (loc_iter.query_nhw(indices_in + output_index * {self.ndim + 1}, nhw_offset)){{
{self.dtype_indices} offset = loc_iter.layout_npq(nhw_offset); auto offset = loc_iter.layout_npq(nhw_offset);
auto item = table.lookup(offset); auto table_offset = table.lookup_offset(offset); // performance bound
if (!item.empty()) {{ if (table_offset != -1){{
auto input_index = item.second; // we find a input indice idx. auto input_index = table.value_ptr()[table_offset]; // we find a input indice idx.
atomicOr(mask1 + output_index, filter_mask_out); atomicOr(mask1 + output_index, filter_mask_out);
atomicOr(mask2 + input_index, filter_mask_in); atomicOr(mask2 + input_index, filter_mask_in);
// for this output, we set correct input idx. // for this output, we set correct input idx.
...@@ -706,10 +722,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -706,10 +722,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.raw(f""" code.raw(f"""
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
int kv = tv::arrayops::prod(ksize); int kv = ksize.op<tv::arrayops::prod>();
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error"); TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(tv::arrayops::prod(input_dims) <= std::numeric_limits<{self.dtype_indices}>::max(),
"kernel volume must smaller than max value of {self.dtype_indices}");
// indice_pairs: [2, kv, indices.dim(0)] // indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1] // indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
tv::check_shape(indice_pairs, {{2, kv, indices.dim(0)}}); tv::check_shape(indice_pairs, {{2, kv, indices.dim(0)}});
...@@ -724,11 +738,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -724,11 +738,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); ConvLocIter loc_iter(problem);
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int)); tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
launcher_clean_uniq(clean_indices_uniq, indice_pairs_uniq.data_ptr<{self.dtype_indices}>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1, loc_iter, indices.data_ptr<const int>(), tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
indice_pairs.data_ptr<{self.dtype_indices}>(), using T = TV_DECLTYPE(I);
indice_pairs_uniq.data_ptr<{self.dtype_indices}>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0), TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
indice_pairs.dim(2), kv, transposed); "kernel volume must smaller than max value of T");
launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1<T>, loc_iter, indices.data_ptr<const int>(),
indice_pairs.data_ptr<{self.dtype_indices}>(),
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
indice_pairs.dim(2), kv, transposed);
}});
// thrust::device_ptr<{self.dtype_indices}> ptr_tr(indice_pairs_uniq.data_ptr<{self.dtype_indices}>()); // thrust::device_ptr<{self.dtype_indices}> ptr_tr(indice_pairs_uniq.data_ptr<{self.dtype_indices}>());
// auto thrust_ctx = thrust::cuda::par.on(reinterpret_cast<cudaStream_t>(stream_int)); // auto thrust_ctx = thrust::cuda::par.on(reinterpret_cast<cudaStream_t>(stream_int));
// thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size); // thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
...@@ -745,11 +765,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -745,11 +765,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.arg("uniq_size", "int64_t") code.arg("uniq_size", "int64_t")
code.arg("stream_int", f"std::uintptr_t", "0") code.arg("stream_int", f"std::uintptr_t", "0")
code.raw(f""" code.raw(f"""
thrust::device_ptr<{self.dtype_indices}> ptr_tr(indice_pairs_uniq.data_ptr<{self.dtype_indices}>()); int num_out_act = 0;
auto thrust_ctx = thrust::cuda::par.on(reinterpret_cast<cudaStream_t>(stream_int)); tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size); using T = TV_DECLTYPE(I);
auto new_end = thrust::unique(thrust_ctx, ptr_tr, ptr_tr + uniq_size); thrust::device_ptr<T> ptr_tr(indice_pairs_uniq.data_ptr<T>());
auto num_out_act = new_end - ptr_tr - 1; auto thrust_ctx = thrust::cuda::par.on(reinterpret_cast<cudaStream_t>(stream_int));
thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
auto new_end = thrust::unique(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
num_out_act = new_end - ptr_tr - 1;
}});
return num_out_act; return num_out_act;
""") """)
return code.ret("int") return code.ret("int")
...@@ -757,8 +781,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -757,8 +781,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_conv_inds_stage2(self): def generate_conv_inds_stage2(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("indices, hashdata", "tv::Tensor") code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg("indice_pairs, indice_pairs_uniq, out_inds", "tv::Tensor") code.arg("indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds", "tv::Tensor")
code.arg("num_out_act", "int") code.arg("num_out_act", "int")
code.arg("batch_size", "int") code.arg("batch_size", "int")
code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>") code.arg("output_dims, input_dims", f"tv::array<int, {self.ndim}>")
...@@ -770,8 +794,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -770,8 +794,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto custream = reinterpret_cast<cudaStream_t>(stream_int); auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
int kv = tv::arrayops::prod(ksize); int kv = ksize.op<tv::arrayops::prod>();
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error"); TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error");
TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error");
// indice_pairs: [2, kv, indices.dim(0)] // indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1] // indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
// out_inds: [MaxSize, {self.ndim + 1}] // out_inds: [MaxSize, {self.ndim + 1}]
...@@ -787,22 +814,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -787,22 +814,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// TODO handle invalid num_out_act // TODO handle invalid num_out_act
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act); indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream); tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
using V = {self.dtype_indices}; tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using KeyType = {self.dtype_indices}; using V = {self.dtype_indices};
constexpr KeyType kEmptyKey = std::numeric_limits<KeyType>::max(); using K = TV_DECLTYPE(I);
using table_t = using table_t =
tv::hash::LinearHashTable<KeyType, V, tv::hash::Murmur3Hash<KeyType>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
using pair_t = typename table_t::value_type; TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= num_out_act, "hash size not enough");
TV_ASSERT_RT_ERR(hashdata.dim(0) >= num_out_act, "hash size not enough"); table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
table_t hash = table_t(hashdata.data_ptr<pair_t>(), hashdata.dim(0)); tv::hash::clear_map_split(hash, custream);
hash.clear(custream); // hash.clear(custream);
lanucher_build_hash(build_conv_hash_table<table_t>, hash, lanucher_build_hash(build_conv_hash_table<table_t>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const {self.dtype_indices}>(), out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act); loc_iter.layout_npq, num_out_act);
launcher_num_act_in(calc_conv_indices_stage2<table_t>, hash, launcher_num_act_in(calc_conv_indices_stage2<table_t>, hash,
indice_pairs[1].data_ptr<int>(), indices.dim(0), indice_pairs_uniq_before_sort.data_ptr<const K>(),
indice_pairs.dim(2)); indice_pairs[1].data_ptr<int>(),
indices.dim(0),
indice_pairs.dim(2));
}});
return num_out_act; return num_out_act;
""") """)
return code.ret("int") return code.ret("int")
...@@ -824,9 +854,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -824,9 +854,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code.raw(f""" code.raw(f"""
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
int kv = tv::arrayops::prod(ksize); int kv = ksize.op<tv::arrayops::prod>();
TV_ASSERT_RT_ERR(tv::arrayops::prod(input_dims) <= std::numeric_limits<{self.dtype_indices}>::max(),
"kernel volume must smaller than max value of {self.dtype_indices}");
// indice_pairs_bwd: [kv, indices.dim(0)] // indice_pairs_bwd: [kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs_bwd.size() + 1] // indice_pairs_uniq: [indice_pairs_bwd.size() + 1]
tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}}); tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}});
...@@ -842,20 +870,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -842,20 +870,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); ConvLocIter loc_iter(problem);
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int)); tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
launcher_clean_uniq(clean_indices_uniq, indice_pairs_uniq.data_ptr<{self.dtype_indices}>(), uniq_size); tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
launcher_num_act_in(calc_conv_indices_stage1_mask, loc_iter, indices.data_ptr<const int>(), using T = TV_DECLTYPE(I);
indice_pairs_bwd.data_ptr<{self.dtype_indices}>(), TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
indice_pairs_uniq.data_ptr<{self.dtype_indices}>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0), "kernel volume must smaller than max value of T");
kv, transposed); launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1_mask<T>, loc_iter, indices.data_ptr<const int>(),
indice_pairs_bwd.data_ptr<{self.dtype_indices}>(),
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
kv, transposed);
}});
""") """)
return code # .ret("int") return code # .ret("int")
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_conv_inds_stage2_mask(self): def generate_conv_inds_stage2_mask(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("indices, hashdata", "tv::Tensor") code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg( code.arg(
"indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds", "indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds",
"tv::Tensor") "tv::Tensor")
code.arg("mask_fwd, mask_bwd", "tv::Tensor") code.arg("mask_fwd, mask_bwd", "tv::Tensor")
...@@ -870,12 +903,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -870,12 +903,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto custream = reinterpret_cast<cudaStream_t>(stream_int); auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
int kv = tv::arrayops::prod(ksize); int kv = ksize.op<tv::arrayops::prod>();
// indice_pairs_bwd: [kv, indices.dim(0)] // indice_pairs_bwd: [kv, indices.dim(0)]
// indice_pairs_fwd: [kv, out_inds.dim(0)] // indice_pairs_fwd: [kv, out_inds.dim(0)]
auto ctx = tv::Context(); auto ctx = tv::Context();
ctx.set_cuda_stream(custream); ctx.set_cuda_stream(custream);
TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error");
TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error");
// out_inds: [MaxSize, {self.ndim + 1}] // out_inds: [MaxSize, {self.ndim + 1}]
// auto timer = tv::CudaContextTimer<>(); // auto timer = tv::CudaContextTimer<>();
tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}}); tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}});
...@@ -892,45 +926,48 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -892,45 +926,48 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// TODO handle invalid num_out_act // TODO handle invalid num_out_act
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act); indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream); tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
using V = {self.dtype_indices}; tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using KeyType = {self.dtype_indices}; using V = {self.dtype_indices};
constexpr KeyType kEmptyKey = std::numeric_limits<KeyType>::max(); using K = TV_DECLTYPE(I);
using table_t = using table_t =
tv::hash::LinearHashTable<KeyType, V, tv::hash::Murmur3Hash<KeyType>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
using pair_t = typename table_t::value_type; TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= num_out_act, "hash size not enough");
TV_ASSERT_RT_ERR(hashdata.dim(0) >= num_out_act, "hash size not enough"); table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
table_t hash = table_t(hashdata.data_ptr<pair_t>(), hashdata.dim(0)); tv::hash::clear_map_split(hash, custream);
hash.clear(custream);
lanucher_build_hash(build_conv_hash_table<table_t>, hash, lanucher_build_hash(build_conv_hash_table<table_t>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const {self.dtype_indices}>(), out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act); loc_iter.layout_npq, num_out_act);
if (!mask_bwd.empty()){{ if (!mask_bwd.empty()){{
// auto timer = tv::CudaContextTimer<>(); // auto timer = tv::CudaContextTimer<>();
launcher_num_act_in(calc_conv_indices_stage2_mask<table_t>, hash, launcher_num_act_in(calc_conv_indices_stage2_mask<table_t>, hash,
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(), indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(), indice_pairs_uniq_before_sort.data_ptr<K>(),
indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1)); mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
// tv::ssprint("calc_conv_indices_stage2_mask", timer.report() / 1000.0); indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1));
launcher_num_act_in_no_y(calc_conv_indices_stage2_mask_output, indice_pairs_bwd.data_ptr<int>(), // tv::ssprint("calc_conv_indices_stage2_mask", timer.report() / 1000.0);
mask_bwd.data_ptr<uint32_t>(), launcher_num_act_in_no_y(calc_conv_indices_stage2_mask_output, indice_pairs_bwd.data_ptr<int>(),
indice_pairs_bwd.dim(1), kv); mask_bwd.data_ptr<uint32_t>(),
// tv::ssprint("calc_conv_indices_stage2_mask_output", timer.report() / 1000.0); indice_pairs_bwd.dim(1), kv);
if (mask_fwd.dim(0) == 2){{ // tv::ssprint("calc_conv_indices_stage2_mask_output", timer.report() / 1000.0);
mask_fwd[1].copy_(mask_fwd[0], ctx); if (mask_fwd.dim(0) == 2){{
}} mask_fwd[1].copy_(mask_fwd[0], ctx);
if (mask_bwd.dim(0) == 2){{ }}
mask_bwd[1].copy_(mask_bwd[0], ctx); if (mask_bwd.dim(0) == 2){{
}} mask_bwd[1].copy_(mask_bwd[0], ctx);
}}else{{ }}
launcher_num_act_in(calc_conv_indices_stage2_inference_mask<table_t>, hash, }}else{{
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(), launcher_num_act_in(calc_conv_indices_stage2_inference_mask<table_t>, hash,
mask_fwd.data_ptr<uint32_t>(), indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1)); indice_pairs_uniq_before_sort.data_ptr<K>(),
if (mask_fwd.dim(0) == 2){{ mask_fwd.data_ptr<uint32_t>(),
mask_fwd[1].copy_(mask_fwd[0], ctx); indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1));
if (mask_fwd.dim(0) == 2){{
mask_fwd[1].copy_(mask_fwd[0], ctx);
}}
}} }}
}} }});
return num_out_act; return num_out_act;
""") """)
return code.ret("int") return code.ret("int")
...@@ -938,7 +975,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -938,7 +975,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.static_function @pccm.cuda.static_function
def generate_subm_conv_inds(self): def generate_subm_conv_inds(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("indices, hashdata", "tv::Tensor") code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor") code.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor")
code.arg("batch_size", "int") code.arg("batch_size", "int")
code.arg("input_dims", f"tv::array<int, {self.ndim}>") code.arg("input_dims", f"tv::array<int, {self.ndim}>")
...@@ -953,7 +990,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -953,7 +990,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto ctx = tv::Context(); auto ctx = tv::Context();
ctx.set_cuda_stream(custream); ctx.set_cuda_stream(custream);
if (!indice_pair_mask.empty()){{ if (!indice_pair_mask.empty()){{
TV_ASSERT_INVALID_ARG(tv::arrayops::prod(ksize) < 32, "for now only support 32bit mask"); TV_ASSERT_INVALID_ARG(ksize.op<tv::arrayops::prod>() <= 32, "for now only support 32bit mask");
}} }}
// TODO stream // TODO stream
// TODO handle num input == 0 // TODO handle num input == 0
...@@ -963,7 +1000,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -963,7 +1000,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
stride[i] = 1; stride[i] = 1;
padding[i] = (ksize[i] / 2) * dilation[i]; padding[i] = (ksize[i] / 2) * dilation[i];
}} }}
int kv = tv::arrayops::prod(ksize); int kv = ksize.op<tv::arrayops::prod>();
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error"); TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
// indice_pairs: [2, kv, indices.dim(0)] // indice_pairs: [2, kv, indices.dim(0)]
// out_inds: [MaxSize, {self.ndim + 1}] // out_inds: [MaxSize, {self.ndim + 1}]
...@@ -972,53 +1009,55 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -972,53 +1009,55 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream); tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream);
launcher_num_act_in.blocks.y = (kv / 2) + 1; launcher_num_act_in.blocks.y = (kv / 2) + 1;
// launcher_num_act_in.blocks.y = kv; // launcher_num_act_in.blocks.y = kv;
TV_ASSERT_RT_ERR(tv::arrayops::prod(input_dims) <= std::numeric_limits<{self.dtype_indices}>::max(),
"kernel volume must smaller than max value of {self.dtype_indices}");
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); ConvLocIter loc_iter(problem);
tv::cuda::Launch lanucher_build_hash(indices.dim(0), custream); tv::cuda::Launch lanucher_build_hash(indices.dim(0), custream);
using V = {self.dtype_indices}; tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using KeyType = {self.dtype_indices}; using V = {self.dtype_indices};
constexpr KeyType kEmptyKey = std::numeric_limits<KeyType>::max(); using K = TV_DECLTYPE(I);
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<K>::max(),
using table_t = "kernel volume must smaller than max value of K");
tv::hash::LinearHashTable<KeyType, V, tv::hash::Murmur3Hash<KeyType>,
kEmptyKey, false>; using table_t =
using pair_t = typename table_t::value_type; tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
TV_ASSERT_RT_ERR(hashdata.dim(0) >= indices.dim(0), "hash size not enough"); tv::hash::default_empty_key_v<K>, false>;
table_t hash = table_t(hashdata.data_ptr<pair_t>(), hashdata.dim(0)); TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= indices.dim(0), "hash size not enough");
hash.clear(custream); table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
// tv::ssprint("clear hash time", hashdata.dim(0), timer.report() / 1000.0); tv::hash::clear_map_split(hash, custream);
lanucher_build_hash(build_subm_conv_hash_table<table_t>, hash, indices.data_ptr<const int>(), lanucher_build_hash(build_subm_conv_hash_table<table_t>, hash, indices.data_ptr<const int>(),
loc_iter.layout_npq, indices.dim(0)); loc_iter.layout_npq, indices.dim(0));
// tv::ssprint("build_hash time", timer.report() / 1000.0); // tv::ssprint("build_hash time", timer.report() / 1000.0);
if (!indice_pair_mask.empty()){{ if (!indice_pair_mask.empty()){{
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error"); TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error");
if (indice_pair_mask.dim(0) == 2){{ if (indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0]; auto mask_0 = indice_pair_mask[0];
tv::cuda::Launch lanucher_fill(mask_0.size(), custream); tv::cuda::Launch lanucher_fill(mask_0.size(), custream);
lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), mask_0.size()); lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), mask_0.size());
indice_pair_mask[1].zero_(ctx); indice_pair_mask[1].zero_(ctx);
auto kernel = &calc_subm_conv_indices_split_mask<table_t>; auto kernel = &calc_subm_conv_indices_split_mask<table_t>;
launcher_num_act_in(kernel, loc_iter, hash, launcher_num_act_in(kernel, loc_iter, hash,
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(), indices.data_ptr<int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask[0].data_ptr<uint32_t>(), indice_pair_mask[1].data_ptr<uint32_t>(), indice_pair_mask[0].data_ptr<uint32_t>(), indice_pair_mask[1].data_ptr<uint32_t>(),
indices.dim(0), indice_pairs.dim(2), kv); indices.dim(0), indice_pairs.dim(2), kv);
}}else{{
tv::cuda::Launch lanucher_fill(indice_pair_mask.size(), custream);
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indice_pair_mask.size());
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash,
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv);
}}
}}else{{ }}else{{
tv::cuda::Launch lanucher_fill(indice_pair_mask.size(), custream); launcher_num_act_in(calc_subm_conv_indices<table_t>, loc_iter, hash, indices.data_ptr<int>(),
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indice_pair_mask.size()); indice_pairs.data_ptr<int>(),
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error"); indice_num_per_loc.data_ptr<int>(), indices.dim(0), indice_pairs.dim(2), kv);
launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash,
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv);
}} }}
}}else{{
launcher_num_act_in(calc_subm_conv_indices<table_t>, loc_iter, hash, indices.data_ptr<int>(), }});
indice_pairs.data_ptr<int>(), // tv::ssprint("clear hash time", hashdata.dim(0), timer.report() / 1000.0);
indice_num_per_loc.data_ptr<int>(), indices.dim(0), indice_pairs.dim(2), kv);
}}
// tv::ssprint("gem subm conv inds time", timer.report() / 1000.0); // tv::ssprint("gem subm conv inds time", timer.report() / 1000.0);
return indices.dim(0); return indices.dim(0);
""") """)
...@@ -1057,8 +1096,9 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1057,8 +1096,9 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
stride[i] = 1; stride[i] = 1;
padding[i] = (ksize[i] / 2) * dilation[i]; padding[i] = (ksize[i] / 2) * dilation[i];
}} }}
int kv = tv::arrayops::prod(ksize); int kv = ksize.op<tv::arrayops::prod>();
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<{self.dtype_indices}>::max(),
"kernel volume must smaller than max value of {self.dtype_indices}");
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); ConvLocIter loc_iter(problem);
int indices_pair_size = indice_pairs.dim(2); int indices_pair_size = indice_pairs.dim(2);
...@@ -1116,7 +1156,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1116,7 +1156,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
f"tv::array<int, {self.ndim}>") f"tv::array<int, {self.ndim}>")
code.arg("transposed", f"bool", "false") code.arg("transposed", f"bool", "false")
code.raw(f""" code.raw(f"""
int kv = tv::arrayops::prod(ksize); int kv = ksize.op<tv::arrayops::prod>();
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation); ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem); ConvLocIter loc_iter(problem);
int indices_pair_size = indice_pairs.dim(2); int indices_pair_size = indice_pairs.dim(2);
...@@ -1125,6 +1165,8 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass): ...@@ -1125,6 +1165,8 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash; std::unordered_map<{self.dtype_indices}, {self.dtype_indices}> hash;
auto indices_ptr = indices.data_ptr<{self.dtype_indices}>(); auto indices_ptr = indices.data_ptr<{self.dtype_indices}>();
auto out_inds_ptr = out_inds.data_ptr<{self.dtype_indices}>(); auto out_inds_ptr = out_inds.data_ptr<{self.dtype_indices}>();
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<{self.dtype_indices}>::max(),
"kernel volume must smaller than max value of {self.dtype_indices}");
int indice_in_num = indices.dim(0); int indice_in_num = indices.dim(0);
int num_act = 0; int num_act = 0;
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
from cumm.conv.bases import ConvEnum
from cumm.gemm.core.metaarray import MetaArray, seq from cumm.gemm.core.metaarray import MetaArray, seq
from cumm import dtypes from cumm import dtypes
import pccm import pccm
...@@ -202,14 +201,14 @@ class IndiceMaxPool(pccm.Class): ...@@ -202,14 +201,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed. // if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value; int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}); }});
if (!found){{ if (!found){{
int NumFeatures = 16; int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}} }}
...@@ -244,14 +243,14 @@ class IndiceMaxPool(pccm.Class): ...@@ -244,14 +243,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed. // if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value; int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}); }});
if (!found){{ if (!found){{
int NumFeatures = 16; int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}} }}
...@@ -287,14 +286,14 @@ class IndiceMaxPool(pccm.Class): ...@@ -287,14 +286,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed. // if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value; int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}); }});
if (!found){{ if (!found){{
int NumFeatures = 16; int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}} }}
...@@ -331,14 +330,14 @@ class IndiceMaxPool(pccm.Class): ...@@ -331,14 +330,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed. // if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value; int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}); }});
if (!found){{ if (!found){{
int NumFeatures = 16; int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}} }}
......
...@@ -126,6 +126,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -126,6 +126,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
super().__init__() super().__init__()
self.add_dependency(TensorView, TensorViewHashKernel) self.add_dependency(TensorView, TensorViewHashKernel)
self.add_param_class("layout_ns", layout, "Layout") self.add_param_class("layout_ns", layout, "Layout")
self.add_include("tensorview/hash/ops.h")
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.zyx = zyx self.zyx = zyx
...@@ -447,7 +448,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -447,7 +448,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(point_indice_data.dim(0) >= points.dim(0), "point_indice_data too small") TV_ASSERT_RT_ERR(point_indice_data.dim(0) >= points.dim(0), "point_indice_data too small")
num_per_voxel.zero_(ctx); num_per_voxel.zero_(ctx);
table_t hash = table_t(hashdata.data_ptr<pair_t>(), expected_hash_data_num); table_t hash = table_t(hashdata.data_ptr<pair_t>(), expected_hash_data_num);
hash.clear(custream); tv::hash::clear_map(hash, custream);
auto launcher = tv::cuda::Launch(points.dim(0), custream); auto launcher = tv::cuda::Launch(points.dim(0), custream);
launcher(kernel::build_hash_table<table_t>, hash, points.data_ptr<const {self.dtype}>(), launcher(kernel::build_hash_table<table_t>, hash, points.data_ptr<const {self.dtype}>(),
point_indice_data.data_ptr<int64_t>(), point_indice_data.data_ptr<int64_t>(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment