Commit 7af751dc authored by yan.yan's avatar yan.yan
Browse files

sync

parent 647927ce
# Changelog # Changelog
## [2.1.22] - 2022-4-14
### Added
- add full nvrtc support
- add support for large spatial shape and batch size. if detect large shape, we use int64 instead of int32 when hashing.
## [2.1.21] - 2021-12-9 ## [2.1.21] - 2021-12-9
### Added ### Added
- add sm_37 - add sm_37
......
...@@ -56,7 +56,7 @@ def main(): ...@@ -56,7 +56,7 @@ def main():
is_empty = table.insert_exist_keys(keys, values) is_empty = table.insert_exist_keys(keys, values)
ks, vs, cnt = table.items() ks, vs, cnt = table.items()
cnt_item = cnt.item() cnt_item = cnt.item()
print(cnt, ks[:cnt_item], vs[:cnt_item]) print(cnt, ks[:cnt_item], vs[:cnt_item], is_empty.dtype)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -12,28 +12,50 @@ ...@@ -12,28 +12,50 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from enum import Enum import contextlib
from cumm import tensorview as tv
from typing import Dict, List, Set, Tuple, Union
from spconv.core_cc.cumm.gemm.main import GemmAlgoDesp, GemmMainUnitTest, GemmParams
from spconv.core_cc.cumm.conv.main import ConvAlgoDesp, ConvMainUnitTest, ConvParams
from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType
from cumm.gemm.algospec.core import GemmAlgo, ShuffleStrideType, get_min_arch_of_algo_str, get_available_algo_str_from_arch
from cumm.gemm.codeops import group_by, div_up
from spconv.constants import NDIM_DONT_CARE, SPCONV_BWD_SPLITK
from typing import Optional
import time import time
from enum import Enum
from threading import Lock from threading import Lock
import contextlib from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
from spconv.core import ConvAlgo, AlgoHint from cumm import tensorview as tv
from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType
from cumm.conv.kernel import ConvKernel
from cumm.gemm.kernel import GemmKernel
from cumm.gemm.algospec.core import (GemmAlgo, ShuffleStrideType,
get_available_algo_str_from_arch,
get_min_arch_of_algo_str)
from cumm.gemm.codeops import div_up, group_by
from cumm.nvrtc import CummNVRTCModule, get_cudadevrt_path
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview.gemm import ConvOpType as ConvOpTypeCpp
from cumm.tensorview.gemm import ConvParams, GemmAlgoDesp, GemmParams
from cumm import dtypes
from spconv.constants import (NDIM_DONT_CARE, SPCONV_BWD_SPLITK,
SPCONV_NVRTC_MODE, SPCONV_DEBUG_NVRTC_KERNELS)
from spconv.core import ALL_IMPGEMM_PARAMS, AlgoHint, ConvAlgo
from spconv.core_cc.cumm.conv.main import ConvMainUnitTest
from spconv.core_cc.cumm.gemm.main import GemmMainUnitTest
from spconv.cppconstants import COMPILED_CUDA_ARCHS
from cumm.tensorview.gemm import NVRTCParams
from spconv.tools import CUDAKernelTimer from spconv.tools import CUDAKernelTimer
from cumm.gemm.constants import NVRTCConstants, NVRTCMode
from spconv import algocore
from cumm.conv.main import gen_gemm_kernels as gen_conv_kernels
from cumm.gemm.main import gen_gemm_kernels
ALL_ALGO_DESPS = GemmMainUnitTest.get_all_algo_desp() ALL_ALGO_DESPS = GemmMainUnitTest.get_all_algo_desp()
ALL_CONV_ALGO_DESPS = ConvMainUnitTest.get_all_conv_algo_desp() ALL_CONV_ALGO_DESPS = ConvMainUnitTest.get_all_conv_algo_desp()
_GEMM_STATIC_KEY = Tuple[bool, bool, bool, int, int, int, str, str] _GEMM_STATIC_KEY = Tuple[bool, bool, bool, int, int, int, str, str]
class SimpleGemmAlgoMeta: class SimpleGemmAlgoMeta:
def __init__(self, tile_ms: List[int], tile_ns: List[int], def __init__(self, tile_ms: List[int], tile_ns: List[int],
tile_ks: List[int], tile_ks: List[int],
...@@ -45,22 +67,68 @@ class SimpleGemmAlgoMeta: ...@@ -45,22 +67,68 @@ class SimpleGemmAlgoMeta:
class BestAlgoByProfile: class BestAlgoByProfile:
def __init__(self, algo_desp: GemmAlgoDesp, splitk: int = 1) -> None: def __init__(self, algo_desp: GemmAlgoDesp, arch: Tuple[int, int], splitk: int = 1) -> None:
self.algo_desp = algo_desp self.algo_desp = algo_desp
self.splitk = splitk self.splitk = splitk
self.arch = arch
class BestConvAlgoByProfile: class BestConvAlgoByProfile:
def __init__(self, algo_desp: ConvAlgoDesp, splitk: int = 1) -> None: def __init__(self, algo_desp: ConvAlgoDesp, arch: Tuple[int, int], splitk: int = 1) -> None:
self.algo_desp = algo_desp self.algo_desp = algo_desp
self.splitk = splitk self.splitk = splitk
self.arch = arch
def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel], kernel_name: str):
nvrtc_mode = SPCONV_NVRTC_MODE
nvrtc_params = tv.gemm.NVRTCParams()
nvrtc_params.cumodule = mod.get_cpp_object()
nvrtc_params.mode = nvrtc_mode.value
nvrtc_params.num_threads = ker.num_threads
nvrtc_params.smem_size = ker.smem_size
ns = ker.namespace
if nvrtc_mode == NVRTCMode.DynamicParallism:
nvrtc_params.kernel_name = mod.get_lowered_name(
f"{ns}::nvrtc_kernel")
elif nvrtc_mode == NVRTCMode.KernelAndCPU:
nvrtc_params.kernel_name = mod.get_lowered_name(f"{ns}::{kernel_name}")
nvrtc_params.init_kernel_name = mod.get_lowered_name(
f"{ns}::nvrtc_kernel_cpu_out")
nvrtc_params.param_size = mod.const_values[
f"{ns}::{NVRTCConstants.SIZEOF_KEY}"]
nvrtc_params.param_storage = tv.empty([nvrtc_params.param_size],
tv.uint8, 0)
nvrtc_params.param_storage_cpu = tv.empty(
[nvrtc_params.param_size], tv.uint8, -1, pinned=True)
elif nvrtc_mode == NVRTCMode.Direct:
nvrtc_params.kernel_name = mod.get_lowered_name(f"{ns}::{kernel_name}")
elif nvrtc_mode == NVRTCMode.ConstantMemory:
nvrtc_params.kernel_name = mod.get_lowered_name(f"{ns}::{kernel_name}")
nvrtc_params.init_kernel_name = mod.get_lowered_name(
f"{ns}::nvrtc_kernel_cpu_out")
nvrtc_params.param_size = mod.const_values[
f"{ns}::{NVRTCConstants.SIZEOF_KEY}"]
nvrtc_params.constant_name = mod.get_lowered_name(
f"&{ns}::{NVRTCConstants.CONSTANT_PARAM_KEY}")
nvrtc_params.param_storage = tv.empty([nvrtc_params.param_size],
tv.uint8, 0)
else:
raise NotImplementedError
return nvrtc_params
class SimpleGemm: class SimpleGemm:
def __init__(self, desps: List[GemmAlgoDesp]) -> None: def __init__(self, prebuilt_desps: List[GemmAlgoDesp]) -> None:
self.desps = desps all_desps = [algocore.get_conv_algo_desp_from_param(p) for p in ALL_IMPGEMM_PARAMS]
self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
if SPCONV_DEBUG_NVRTC_KERNELS:
self.prebuilt_desp_names.clear()
self.lock = Lock() self.lock = Lock()
self.static_key_to_desps = group_by(self.get_static_key, desps) self.static_key_to_desps = group_by(self.get_static_key, all_desps)
self.static_key_to_meta: Dict[_GEMM_STATIC_KEY, self.static_key_to_meta: Dict[_GEMM_STATIC_KEY,
SimpleGemmAlgoMeta] = {} SimpleGemmAlgoMeta] = {}
for k, static_desps in self.static_key_to_desps.items(): for k, static_desps in self.static_key_to_desps.items():
...@@ -94,15 +162,44 @@ class SimpleGemm: ...@@ -94,15 +162,44 @@ class SimpleGemm:
self.mn_cache: Dict[Tuple[int, int, int, int, int], self.mn_cache: Dict[Tuple[int, int, int, int, int],
BestAlgoByProfile] = {} # for backward weight BestAlgoByProfile] = {} # for backward weight
self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int]], NVRTCParams] = {}
@staticmethod @staticmethod
def get_static_key(d: GemmAlgoDesp) -> _GEMM_STATIC_KEY: def get_static_key(d: GemmAlgoDesp) -> _GEMM_STATIC_KEY:
return (d.trans_a, d.trans_b, d.trans_c, d.dtype_a, d.dtype_b, return (d.trans_a, d.trans_b, d.trans_c, d.dtype_a, d.dtype_b,
d.dtype_c, d.shuffle_type, d.algo) d.dtype_c, d.shuffle_type.value, d.algo)
def device_synchronize(self): def device_synchronize(self):
return GemmMainUnitTest.device_synchronize() return GemmMainUnitTest.device_synchronize()
def _compile_nvrtc_module(self, desp: GemmAlgoDesp):
params = algocore.get_gemm_param_from_desp(desp)
kernel = gen_gemm_kernels(params, SPCONV_NVRTC_MODE)
kernel.namespace = "spconv"
custom_names = []
if SPCONV_NVRTC_MODE == NVRTCMode.ConstantMemory:
custom_names = [f"&{kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"]
cudadevrt = ""
if SPCONV_NVRTC_MODE == NVRTCMode.DynamicParallism:
cudadevrt_p = get_cudadevrt_path()
assert cudadevrt_p is not None, "DynamicParallism must have cudadevrt"
cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt,
verbose=False,
custom_names=custom_names)
mod.load()
return mod, kernel
def _cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int, int]):
key = (str(desp), arch)
if key in self._nvrtc_caches:
return self._nvrtc_caches[key]
mod, ker = self._compile_nvrtc_module(desp)
nvrtc_params = _get_nvrtc_params(mod, ker, "gemm_kernel")
self._nvrtc_caches[key] = nvrtc_params
return nvrtc_params
def get_all_available( def get_all_available(
self, self,
a: tv.Tensor, a: tv.Tensor,
...@@ -135,6 +232,11 @@ class SimpleGemm: ...@@ -135,6 +232,11 @@ class SimpleGemm:
ldb = b.dim(1) ldb = b.dim(1)
ldc = c.dim(1) ldc = c.dim(1)
if desp.supported_ldx(lda, ldb, ldc): if desp.supported_ldx(lda, ldb, ldc):
if arch not in COMPILED_CUDA_ARCHS:
desp = desp.copy()
desp.is_nvrtc = True
if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True
finally_algos.append(desp) finally_algos.append(desp)
return finally_algos return finally_algos
...@@ -334,6 +436,8 @@ class SimpleGemm: ...@@ -334,6 +436,8 @@ class SimpleGemm:
if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value: if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
split_k_slices = max(min(32, k // 128), 1) split_k_slices = max(min(32, k // 128), 1)
params = GemmParams() params = GemmParams()
if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(desp, arch)
params.a = a params.a = a
params.b = b params.b = b
params.c = c_ params.c = c_
...@@ -361,7 +465,7 @@ class SimpleGemm: ...@@ -361,7 +465,7 @@ class SimpleGemm:
times.append(np.mean(this_times[1:])) times.append(np.mean(this_times[1:]))
spk_speeds.append(times[-1]) spk_speeds.append(times[-1])
all_profile_res.append(BestAlgoByProfile(desp, splitk=spk)) all_profile_res.append(BestAlgoByProfile(desp, arch, splitk=spk))
min_time = 1000 min_time = 1000
min_idx = -1 min_idx = -1
...@@ -421,6 +525,9 @@ class SimpleGemm: ...@@ -421,6 +525,9 @@ class SimpleGemm:
if profile_res.splitk > 1: if profile_res.splitk > 1:
split_k_slices = profile_res.splitk split_k_slices = profile_res.splitk
params = GemmParams() params = GemmParams()
if algo_desp.is_nvrtc and str(algo_desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(algo_desp, profile_res.arch)
params.a = a params.a = a
params.b = b params.b = b
params.c = c params.c = c
...@@ -461,11 +568,14 @@ _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int] ...@@ -461,11 +568,14 @@ _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int]
class SimpleConv: class SimpleConv:
def __init__(self, desps: List[ConvAlgoDesp]) -> None: def __init__(self, prebuilt_desps: List[ConvAlgoDesp]) -> None:
self.desps = desps all_desps = [algocore.get_conv_algo_desp_from_param(p) for p in ALL_IMPGEMM_PARAMS]
self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
self.prebuilt_desp_names.clear()
self.lock = Lock() self.lock = Lock()
self.static_key_to_desps = group_by(self.get_static_key, desps) self.static_key_to_desps = group_by(self.get_static_key, all_desps)
self.static_key_to_meta: Dict[_CONV_STATIC_KEY, self.static_key_to_meta: Dict[_CONV_STATIC_KEY,
SimpleGemmAlgoMeta] = {} SimpleGemmAlgoMeta] = {}
for k, static_desps in self.static_key_to_desps.items(): for k, static_desps in self.static_key_to_desps.items():
...@@ -500,28 +610,36 @@ class SimpleConv: ...@@ -500,28 +610,36 @@ class SimpleConv:
self.kc_wgrad_cache: Dict[Tuple[int, int, int, int, int, int, int, self.kc_wgrad_cache: Dict[Tuple[int, int, int, int, int, int, int,
int], BestConvAlgoByProfile] = { int], BestConvAlgoByProfile] = {
} # for backward weight } # for backward weight
self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int]], NVRTCParams] = {}
@staticmethod @staticmethod
def get_static_key(d: ConvAlgoDesp) -> _CONV_STATIC_KEY: def get_static_key(d: ConvAlgoDesp) -> _CONV_STATIC_KEY:
return (d.layout_i, d.layout_w, d.layout_o, d.interleave_i, return (d.layout_i.value, d.layout_w.value, d.layout_o.value,
d.interleave_w, d.interleave_o, d.dtype_input, d.dtype_weight, d.interleave_i, d.interleave_w, d.interleave_o, d.dtype_input,
d.dtype_output, d.algo, d.op_type) d.dtype_weight, d.dtype_output, d.algo, d.op_type.value)
def device_synchronize(self): def device_synchronize(self):
return GemmMainUnitTest.device_synchronize() return GemmMainUnitTest.device_synchronize()
def get_all_available(self, inp: tv.Tensor, weight: tv.Tensor, def get_all_available(self,
out: tv.Tensor, layout_i: ConvLayout, inp: tv.Tensor,
layout_w: ConvLayout, layout_o: ConvLayout, weight: tv.Tensor,
arch: Tuple[int, int], op_type: ConvOpType, out: tv.Tensor,
mask_width: int, fp32_accum: Optional[bool] = None): layout_i: ConvLayout,
layout_w: ConvLayout,
layout_o: ConvLayout,
arch: Tuple[int, int],
op_type: ConvOpType,
mask_width: int,
fp32_accum: Optional[bool] = None):
avail_algos = get_available_algo_str_from_arch(arch) avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[ConvAlgoDesp] = [] finally_algos: List[ConvAlgoDesp] = []
is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 and out.dtype == tv.float16 is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 and out.dtype == tv.float16
use_f32_as_accum = False use_f32_as_accum = False
kv = int(np.prod(weight.shape[1:-1])) kv = int(np.prod(weight.shape[1:-1]))
# for 3d conv, if reduce axis is too large, may cause nan during # for 3d conv, if reduce axis is too large, may cause nan during
# forward. # forward.
if is_fp16: if is_fp16:
if fp32_accum is None: if fp32_accum is None:
...@@ -551,7 +669,7 @@ class SimpleConv: ...@@ -551,7 +669,7 @@ class SimpleConv:
if use_f32_as_accum: if use_f32_as_accum:
if desp.dacc == tv.float16: if desp.dacc == tv.float16:
continue continue
ldi = inp.dim(-1) ldi = inp.dim(-1)
ldw = weight.dim(-1) ldw = weight.dim(-1)
ldo = out.dim(-1) ldo = out.dim(-1)
...@@ -560,6 +678,11 @@ class SimpleConv: ...@@ -560,6 +678,11 @@ class SimpleConv:
assert mask_width > 0 assert mask_width > 0
mask_width_valid = mask_width % desp.tile_shape[2] == 0 mask_width_valid = mask_width % desp.tile_shape[2] == 0
if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid: if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
if arch not in COMPILED_CUDA_ARCHS:
desp = desp.copy()
desp.is_nvrtc = True
if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True
finally_algos.append(desp) finally_algos.append(desp)
return finally_algos return finally_algos
...@@ -592,6 +715,34 @@ class SimpleConv: ...@@ -592,6 +715,34 @@ class SimpleConv:
return desp.query_conv_workspace_size(mnk[0], mnk[1], mnk[2], splitk, return desp.query_conv_workspace_size(mnk[0], mnk[1], mnk[2], splitk,
kv) kv)
def _compile_nvrtc_module(self, desp: ConvAlgoDesp):
params = algocore.get_conv_param_from_desp(desp)
kernel = gen_conv_kernels(params, SPCONV_NVRTC_MODE)
kernel.namespace = "spconv"
custom_names = []
if SPCONV_NVRTC_MODE == NVRTCMode.ConstantMemory:
custom_names = [f"&{kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"]
cudadevrt = ""
if SPCONV_NVRTC_MODE == NVRTCMode.DynamicParallism:
cudadevrt_p = get_cudadevrt_path()
assert cudadevrt_p is not None, "DynamicParallism must have cudadevrt"
cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt,
verbose=False,
custom_names=custom_names)
mod.load()
return mod, kernel
def _cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int]):
key = (str(desp), arch)
if key in self._nvrtc_caches:
return self._nvrtc_caches[key]
mod, ker = self._compile_nvrtc_module(desp)
nvrtc_params = _get_nvrtc_params(mod, ker, "conv_kernel")
self._nvrtc_caches[key] = nvrtc_params
return nvrtc_params
def tune_and_cache(self, def tune_and_cache(self,
op_type: ConvOpType, op_type: ConvOpType,
inp: tv.Tensor, inp: tv.Tensor,
...@@ -613,7 +764,7 @@ class SimpleConv: ...@@ -613,7 +764,7 @@ class SimpleConv:
stream: int = 0, stream: int = 0,
fp32_accum: Optional[bool] = None): fp32_accum: Optional[bool] = None):
avail = self.get_all_available(inp, weight, output, layout_i, layout_w, avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, arch, op_type, mask_width, layout_o, arch, op_type, mask_width,
fp32_accum) fp32_accum)
inp = inp.clone() inp = inp.clone()
weight = weight.clone() weight = weight.clone()
...@@ -626,7 +777,10 @@ class SimpleConv: ...@@ -626,7 +777,10 @@ class SimpleConv:
all_profile_res: List[BestConvAlgoByProfile] = [] all_profile_res: List[BestConvAlgoByProfile] = []
for desp in avail: for desp in avail:
# for sparse conv, ndim isn't used, so we just provide a constant value. # for sparse conv, ndim isn't used, so we just provide a constant value.
params = ConvParams(NDIM_DONT_CARE, op_type.value) params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type.value))
if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(desp, arch)
params.conv_algo_desp = desp params.conv_algo_desp = desp
params.input = inp params.input = inp
params.weight = weight.view([channel_k, -1, channel_c]) params.weight = weight.view([channel_k, -1, channel_c])
...@@ -657,13 +811,16 @@ class SimpleConv: ...@@ -657,13 +811,16 @@ class SimpleConv:
GemmMainUnitTest.stream_synchronize(stream) GemmMainUnitTest.stream_synchronize(stream)
t = time.time() t = time.time()
params.split_k_slices = spk params.split_k_slices = spk
ConvMainUnitTest.implicit_gemm2(params) if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names:
tv.gemm.run_nvrtc_conv_kernel(params)
else:
ConvMainUnitTest.implicit_gemm2(params)
GemmMainUnitTest.stream_synchronize(stream) GemmMainUnitTest.stream_synchronize(stream)
this_times.append(time.time() - t) this_times.append(time.time() - t)
times.append(np.mean(this_times[1:])) times.append(np.mean(this_times[1:]))
spk_speeds.append(times[-1]) spk_speeds.append(times[-1])
all_profile_res.append(BestConvAlgoByProfile(desp, splitk=spk)) all_profile_res.append(BestConvAlgoByProfile(desp, arch, splitk=spk))
if not all_profile_res: if not all_profile_res:
raise ValueError("can't find suitable algorithm for", op_type) raise ValueError("can't find suitable algorithm for", op_type)
min_time = 1000 min_time = 1000
...@@ -720,7 +877,9 @@ class SimpleConv: ...@@ -720,7 +877,9 @@ class SimpleConv:
op_type_value = op_type op_type_value = op_type
else: else:
op_type_value = op_type.value op_type_value = op_type.value
params = ConvParams(NDIM_DONT_CARE, op_type_value) params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type_value))
if algo_desp.is_nvrtc and str(algo_desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(algo_desp, profile_res.arch)
params.conv_algo_desp = profile_res.algo_desp params.conv_algo_desp = profile_res.algo_desp
params.input = inp params.input = inp
params.verbose = verbose params.verbose = verbose
......
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Set, Tuple, Union
from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType
from cumm.gemm.algospec.core import (GemmAlgo, ShuffleStrideType)
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview.gemm import ConvIterAlgo as ConvIterAlgoCpp
from cumm.tensorview.gemm import ConvOpType as ConvOpTypeCpp
from cumm.tensorview.gemm import ConvLayoutType as ConvLayoutTypeCpp
from cumm.tensorview.gemm import ShuffleStrideType as ShuffleStrideTypeCpp
from cumm.tensorview.gemm import ConvParams, GemmAlgoDesp, GemmParams
from cumm.gemm.main import GemmAlgoParams
from cumm.conv.main import ConvAlgoParams, ConvIterAlgo
from cumm import dtypes
from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout,
ConvLayoutType, ConvMode, ConvOpType)
from cumm.gemm.core import MetaArray
from cumm.gemm.algospec import TensorOp
def _assign_gemm_desp_props(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
p: Union[GemmAlgoParams, ConvAlgoParams]):
desp.dtype_a = p.dtype_a.tv_dtype
desp.dtype_b = p.dtype_a.tv_dtype
desp.dtype_c = p.dtype_a.tv_dtype
desp.dacc = p.dtype_acc.tv_dtype
desp.dcomp = p.dtype_comp.tv_dtype
desp.trans_a = p.trans_a
desp.trans_b = p.trans_b
desp.trans_c = p.trans_c
desp.tile_shape = (p.ts[0], p.ts[1], p.ts[2])
desp.warp_tile_shape = (p.wts[0], p.wts[1], p.wts[2])
if p.tensorop is not None:
desp.tensorop = (p.tensorop[0], p.tensorop[1], p.tensorop[2])
desp.num_stage = p.num_stage
desp.algo = p.algo.value
desp.split_k_serial = p.splitk_serial
desp.split_k_parallel = p.splitk_parallel
desp.shuffle_type = ShuffleStrideTypeCpp(p.shuffle_stride.value)
desp.access_per_vector = p.access_per_vector
desp.is_nvrtc = p.is_nvrtc
def get_gemm_algo_desp_from_param(p: GemmAlgoParams):
desp = GemmAlgoDesp()
_assign_gemm_desp_props(desp, p)
return desp
def get_conv_algo_desp_from_param(p: ConvAlgoParams):
desp = ConvAlgoDesp(p.ndim, ConvOpTypeCpp(p.op_type.value))
_assign_gemm_desp_props(desp, p)
# conv attrs
desp.ndim = p.ndim
desp.op_type = ConvOpTypeCpp(p.op_type.value)
desp.iter_algo = ConvIterAlgoCpp(p.iter_algo.value)
desp.layout_i = ConvLayoutTypeCpp(p.layout_desp_input.layout_type.value)
desp.layout_w = ConvLayoutTypeCpp(p.layout_desp_weight.layout_type.value)
desp.layout_o = ConvLayoutTypeCpp(p.layout_desp_output.layout_type.value)
desp.interleave_i = p.layout_desp_input.interleave
desp.interleave_w = p.layout_desp_weight.interleave
desp.interleave_o = p.layout_desp_output.interleave
desp.mask_sparse = p.mask_sparse
desp.increment_k_first = p.increment_k_first
return desp
def _assign_gemm_params(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
p: Union[GemmAlgoParams, ConvAlgoParams]):
p.dtype_a = dtypes.get_dtype_from_tvdtype(desp.dtype_a)
p.dtype_b = dtypes.get_dtype_from_tvdtype(desp.dtype_b)
p.dtype_c = dtypes.get_dtype_from_tvdtype(desp.dtype_c)
p.dtype_acc = dtypes.get_dtype_from_tvdtype(desp.dacc)
p.dtype_comp = dtypes.get_dtype_from_tvdtype(desp.dcomp)
p.trans_a = desp.trans_a
p.trans_b = desp.trans_b
p.trans_c = desp.trans_c
p.ts = MetaArray(*desp.tile_shape)
p.wts = MetaArray(*desp.warp_tile_shape)
if desp.tensorop[0] > 0:
p.tensorop = TensorOp(
(desp.tensorop[0], desp.tensorop[1], desp.tensorop[2]))
p.num_stage = desp.num_stage
p.algo = GemmAlgo(desp.algo)
p.splitk_serial = desp.split_k_serial
p.splitk_parallel = desp.split_k_parallel
p.shuffle_stride = ShuffleStrideType(desp.shuffle_type.value)
p.access_per_vector = desp.access_per_vector
p.is_nvrtc = desp.is_nvrtc
def get_gemm_param_from_desp(desp: GemmAlgoDesp):
p = GemmAlgoParams((0, 0, 0), (0, 0, 0), 0, "s8,s8,s8,s8,s8", False, False,
False, GemmAlgo.Simt)
_assign_gemm_params(desp, p)
return p
def get_conv_param_from_desp(desp: ConvAlgoDesp):
p = ConvAlgoParams(desp.ndim, ConvOpType.kForward, ConvIterAlgo.Optimized,
(0, 0, 0), (0, 0, 0), 0, "s8,s8,s8,s8,s8", NHWC, NHWC, NHWC,
GemmAlgo.Simt)
_assign_gemm_params(desp, p)
# conv attrs
p.ndim = desp.ndim
p.op_type = ConvOpType(desp.op_type.value)
p.iter_algo = ConvIterAlgo(desp.iter_algo.value)
p.layout_desp_input = ConvLayout(ConvLayoutType(desp.layout_i.value),
desp.interleave_i)
p.layout_desp_weight = ConvLayout(ConvLayoutType(desp.layout_w.value),
desp.interleave_w)
p.layout_desp_output = ConvLayout(ConvLayoutType(desp.layout_o.value),
desp.interleave_o)
p.mask_sparse = desp.mask_sparse
p.increment_k_first = desp.increment_k_first
return p
...@@ -29,21 +29,20 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable( ...@@ -29,21 +29,20 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from cumm.common import CompileInfo from cumm.common import CompileInfo
from spconv.csrc.sparse.all import SpconvOps from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.sparse.alloc import ExternalAllocator
from spconv.csrc.utils import BoxOps from spconv.csrc.utils import BoxOps
from spconv.csrc.hash.core import HashTable from spconv.csrc.hash.core import HashTable
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
SHUFFLE_TURING_PARAMS) cu = GemmMainUnitTest(all_shuffle)
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
IMPLGEMM_TURING_PARAMS) IMPLGEMM_TURING_PARAMS)
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main" convcu.namespace = "cumm.conv.main"
objects_folder = None pccm.builder.build_pybind([cu, convcu, SpconvOps(), BoxOps(), HashTable(), CompileInfo(), ExternalAllocator()],
if InWindows:
# windows have command line limit, so we use objects_folder to reduce command size.
objects_folder = "objects"
pccm.builder.build_pybind([cu, convcu, SpconvOps(), BoxOps(), HashTable(), CompileInfo()],
PACKAGE_ROOT / "core_cc", PACKAGE_ROOT / "core_cc",
namespace_root=PACKAGE_ROOT, namespace_root=PACKAGE_ROOT,
objects_folder=objects_folder,
load_library=False) load_library=False)
...@@ -16,6 +16,8 @@ import os ...@@ -16,6 +16,8 @@ import os
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from pccm.utils import project_is_editable, project_is_installed from pccm.utils import project_is_editable, project_is_installed
from cumm.gemm.constants import NVRTCMode
import enum
PACKAGE_NAME = "spconv" PACKAGE_NAME = "spconv"
PACKAGE_ROOT = Path(__file__).parent.resolve() PACKAGE_ROOT = Path(__file__).parent.resolve()
...@@ -43,3 +45,21 @@ else: ...@@ -43,3 +45,21 @@ else:
# for f16 backward weight, larger splitk, larger compute error. # for f16 backward weight, larger splitk, larger compute error.
# so we use this env to control maximum splitk. # so we use this env to control maximum splitk.
SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,64").split(","))) SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,64").split(",")))
SPCONV_NVRTC_MODE = NVRTCMode.ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS = False
class SpconvAllocatorKeys:
Pair = "Pair"
IndiceNumPerLoc = "IndiceNumPerLoc"
PairMask = "PairMask"
MaskArgSort = "MaskArgSort"
OutIndices = "OutIndices"
PairFwd = "PairFwd"
# PairMaskFwd = "PairMaskFwd"
PairMaskBwd = "PairMaskBwd"
# MaskArgSortFwd = "MaskArgSortFwd"
MaskArgSortBwd = "MaskArgSortBwd"
OutFeatures = "OutFeatures"
...@@ -15,17 +15,17 @@ from enum import Enum ...@@ -15,17 +15,17 @@ from enum import Enum
from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgoParams from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgoParams
from cumm.gemm import kernel from cumm.gemm import kernel
from typing import List from typing import List
from cumm.gemm.algospec.core import TensorOpParams from cumm.gemm.algospec.core import TensorOp
from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvIterAlgo, GemmAlgo from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvIterAlgo, GemmAlgo
from cumm.conv.bases import (NCHW, NHWC, ConvEnum, ConvIterAlgo, ConvLayout, from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout,
ConvLayoutType, ConvMode, ConvOpType) ConvLayoutType, ConvMode, ConvOpType)
from spconv.constants import NDIM_DONT_CARE from spconv.constants import NDIM_DONT_CARE
class ConvAlgo(Enum): class ConvAlgo(Enum):
Native = "Native" Native = 0
MaskImplicitGemm = "MaskImplicitGemm" MaskImplicitGemm = 1
MaskSplitImplicitGemm = "MaskSplitImplicitGemm" MaskSplitImplicitGemm = 2
class AlgoHint(Enum): class AlgoHint(Enum):
...@@ -40,17 +40,17 @@ class AlgoHint(Enum): ...@@ -40,17 +40,17 @@ class AlgoHint(Enum):
SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), 2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), 2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
*gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], *gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"],
"", 2, kernel.GemmAlgo.SimtDP4A, None), "", 2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2, (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
kernel.GemmAlgo.SimtDP4A, None), kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.SimtDP4A, None), 2, kernel.GemmAlgo.SimtDP4A, None, is_nvrtc=True),
*gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"], *gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"],
"f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None),
# *gen_shuffle_params( # *gen_shuffle_params(
...@@ -104,88 +104,88 @@ SHUFFLE_VOLTA_PARAMS: List[GemmAlgoParams] = [ ...@@ -104,88 +104,88 @@ SHUFFLE_VOLTA_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params( *gen_shuffle_params(
(64, 64, 32), (64, 64, 32),
(32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
# *gen_shuffle_params( # *gen_shuffle_params(
# (128, 128, 32), # (128, 128, 32),
# (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, # (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), # kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 256, 32), (128, 256, 32),
(64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
*gen_shuffle_params( *gen_shuffle_params(
(256, 128, 32), (256, 128, 32),
(64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 64, 32), (128, 64, 32),
(64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
*gen_shuffle_params( *gen_shuffle_params(
(64, 128, 32), (64, 128, 32),
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
] ]
# SHUFFLE_VOLTA_PARAMS = [] # SHUFFLE_VOLTA_PARAMS = []
SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
*gen_shuffle_params( *gen_shuffle_params(
(64, 64, 32), (64, 64, 32),
(32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
# *gen_shuffle_params( # *gen_shuffle_params(
# (128, 128, 32), # (128, 128, 32),
# (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, # (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), # kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(64, 64, 64), (64, 64, 64),
(32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(64, 128, 64), (64, 128, 64),
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 256, 32), (128, 256, 32),
(64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(256, 128, 32), (256, 128, 32),
(64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 64, 32), (128, 64, 32),
(64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params( *gen_shuffle_params(
(64, 128, 32), (64, 128, 32),
(32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True),
*gen_shuffle_params( *gen_shuffle_params(
(128, 128, 32), (128, 128, 32),
(32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOpParams((8, 8, 16))), TensorOp((8, 8, 16)), is_nvrtc=True),
# *gen_shuffle_params( # *gen_shuffle_params(
# (128, 128, 32), # (128, 128, 32),
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2, # (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
# kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), # kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*gen_shuffle_params( *gen_shuffle_params(
(128, 256, 32), (128, 256, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOpParams((8, 8, 16))), TensorOp((8, 8, 16)), is_nvrtc=True),
*gen_shuffle_params( *gen_shuffle_params(
(256, 128, 32), (256, 128, 32),
(64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, (64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing,
TensorOpParams((8, 8, 16))), TensorOp((8, 8, 16)), is_nvrtc=True),
*gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True),
*gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "",
2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), 2, kernel.GemmAlgo.Turing, TensorOp((8, 8, 16)), is_nvrtc=True),
] ]
# SHUFFLE_TURING_PARAMS = [] # SHUFFLE_TURING_PARAMS = []
...@@ -399,6 +399,34 @@ IMPLGEMM_SIMT_PARAMS = [ ...@@ -399,6 +399,34 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
] ]
IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 32, 16), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Simt,
None,
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 32, 16), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Simt,
None,
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
]
IMPLGEMM_VOLTA_PARAMS = [ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
...@@ -408,7 +436,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -408,7 +436,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -420,7 +448,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -420,7 +448,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=0), access_per_vector=0),
...@@ -432,7 +460,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -432,7 +460,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -444,7 +472,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -444,7 +472,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -456,7 +484,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -456,7 +484,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -468,7 +496,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -468,7 +496,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=0), access_per_vector=0),
...@@ -480,7 +508,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -480,7 +508,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Volta, GemmAlgo.Volta,
TensorOpParams((8, 8, 4)), TensorOp((8, 8, 4)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -495,7 +523,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -495,7 +523,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=0), access_per_vector=0),
...@@ -507,7 +535,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -507,7 +535,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -519,7 +547,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -519,7 +547,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -531,7 +559,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -531,7 +559,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -543,7 +571,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -543,7 +571,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -555,7 +583,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -555,7 +583,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -567,7 +595,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -567,7 +595,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -579,7 +607,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -579,7 +607,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -591,7 +619,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -591,7 +619,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -603,7 +631,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -603,7 +631,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -615,7 +643,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -615,7 +643,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -628,7 +656,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -628,7 +656,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=0), access_per_vector=0),
...@@ -641,7 +669,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -641,7 +669,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
...@@ -654,12 +682,16 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -654,12 +682,16 @@ IMPLGEMM_TURING_PARAMS = [
NHWC, NHWC,
NHWC, NHWC,
GemmAlgo.Turing, GemmAlgo.Turing,
TensorOpParams((16, 8, 8)), TensorOp((16, 8, 8)),
mask_sparse=True, mask_sparse=True,
increment_k_first=True, increment_k_first=True,
access_per_vector=1), access_per_vector=1),
# *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32", # *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32",
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), # NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, ) # gen_conv_params(ConvFwdAndBwdInput, )
] ]
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS
ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS
...@@ -34,13 +34,15 @@ class SpconvOps: ...@@ -34,13 +34,15 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_stage2(indices: Tensor, hashdata: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor, out_inds: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int: def generate_conv_inds_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
hashdata: hashdata_k:
hashdata_v:
indice_pairs: indice_pairs:
indice_pairs_uniq: indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds: out_inds:
num_out_act: num_out_act:
batch_size: batch_size:
...@@ -74,14 +76,16 @@ class SpconvOps: ...@@ -74,14 +76,16 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int: def generate_conv_inds_mask_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
hashdata: hashdata_k:
hashdata_v:
indice_pairs_fwd: indice_pairs_fwd:
indice_pairs_bwd: indice_pairs_bwd:
indice_pairs_uniq: indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds: out_inds:
mask_fwd: mask_fwd:
mask_bwd: mask_bwd:
...@@ -98,11 +102,12 @@ class SpconvOps: ...@@ -98,11 +102,12 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int: def generate_subm_conv_inds(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
""" """
Args: Args:
indices: indices:
hashdata: hashdata_k:
hashdata_v:
indice_pairs: indice_pairs:
out_inds: out_inds:
indice_num_per_loc: indice_num_per_loc:
...@@ -276,6 +281,18 @@ class SpconvOps: ...@@ -276,6 +281,18 @@ class SpconvOps:
""" """
... ...
@staticmethod @staticmethod
def sort_1d_by_key_split_allocator_v2(data: Tensor, allocator, mask: Tensor, indices: Tensor = Tensor(), stream: int = 0, mask_output: bool = False) -> Tensor:
"""
Args:
data:
allocator:
mask:
indices:
stream:
mask_output:
"""
...
@staticmethod
def count_bits(a: Tensor) -> Tensor: def count_bits(a: Tensor) -> Tensor:
""" """
Args: Args:
...@@ -328,3 +345,51 @@ class SpconvOps: ...@@ -328,3 +345,51 @@ class SpconvOps:
stream_int: stream_int:
""" """
... ...
@staticmethod
def get_int32_max() -> int: ...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0) -> Tensor:
"""
Args:
allocator:
indices:
batch_size:
input_dims:
algo:
ksize:
stride:
padding:
dilation:
out_padding:
subm:
transposed:
is_train:
stream_int:
"""
...
@staticmethod
def get_indice_pairs(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, stream_int: int = 0) -> None:
"""
Args:
allocator:
indices:
batch_size:
input_dims:
algo:
ksize:
stride:
padding:
dilation:
out_padding:
subm:
transposed:
stream_int:
"""
...
@staticmethod
def test_allocator(allocator) -> None:
"""
Args:
allocator:
"""
...
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
"""
...
def empty(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
"""
...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
value:
dtype:
device:
"""
...
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
value:
dtype:
device:
"""
...
def free(self, ten: Tensor) -> None:
"""
Args:
ten:
"""
...
def free_noexcept(self, ten: Tensor) -> None:
"""
Args:
ten:
"""
...
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue from pccm.stubs import EnumValue, EnumClassValue
from ...cumm.gemm.main import GemmAlgoDesp from cumm.tensorview.gemm import ConvParams
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ConvAlgoDesp(GemmAlgoDesp):
ndim: int
op_type: int
iter_algo: int
layout_i: int
layout_w: int
layout_o: int
interleave_i: int
interleave_w: int
interleave_o: int
mask_sparse: bool
increment_k_first: bool
def __init__(self, ndim: int, op_type: int) -> None:
"""
Args:
ndim:
op_type:
"""
...
def __repr__(self) -> str: ...
@staticmethod
def conv_iwo_012_to_abc(op_type: int) -> List[int]:
"""
Args:
op_type:
"""
...
@staticmethod
def gemm_abc_012_to_iwo(op_type: int) -> List[int]:
"""
Args:
op_type:
"""
...
@property
def dtype_input(self) -> int: ...
@property
def dtype_weight(self) -> int: ...
@property
def dtype_output(self) -> int: ...
def supported(self, m: int, n: int, k: int, C: int, K: int, mask_width: int) -> bool:
"""
Args:
m:
n:
k:
C:
K:
mask_width:
"""
...
def query_conv_workspace_size(self, m: int, n: int, k: int, split_k_slices: int, kv: int) -> int:
"""
Args:
m:
n:
k:
split_k_slices:
kv:
"""
...
def supported_ldx_conv(self, ldi: int, ldw: int, ldo: int) -> bool:
"""
Args:
ldi:
ldw:
ldo:
"""
...
class ConvParams:
conv_algo_desp: Any
input: Tensor
weight: Tensor
output: Tensor
split_k_slices: int
padding: List[int]
stride: List[int]
dilation: List[int]
alpha: float
beta: float
mask_width: int
mask_filter: int
reverse_mask: bool
verbose: bool
timer: CUDAKernelTimer
workspace: Tensor = Tensor()
mask: Tensor = Tensor()
mask_argsort: Tensor = Tensor()
indices: Tensor = Tensor()
mask_output: Tensor = Tensor()
stream: int
def __init__(self, ndim: int, op_type: int, timer: CUDAKernelTimer = CUDAKernelTimer(False)) -> None:
"""
Args:
ndim:
op_type:
timer:
"""
...
class ConvMainUnitTest: class ConvMainUnitTest:
@staticmethod @staticmethod
def extract_mnk(op_type: int, N: int, C: int, K: int, kernel_volume: int, in_prod: int, out_prod: int, mask_sparse: bool) -> List[int]: def extract_mnk(op_type: int, N: int, C: int, K: int, kernel_volume: int, in_prod: int, out_prod: int, mask_sparse: bool) -> List[int]:
......
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor from cumm.tensorview.gemm import GemmParams
from cumm.tensorview import CUDAKernelTimer
class GemmAlgoDesp:
dtype_a: int
dtype_b: int
dtype_c: int
tile_shape: Tuple[int, int, int]
warp_tile_shape: Tuple[int, int, int]
num_stage: int
dacc: int
dcomp: int
algo: str
tensorop: List[int]
split_k_serial_: int
split_k_parallel_: int
shuffle_type: str
element_per_access_a: int
element_per_access_b: int
element_per_access_c: int
access_per_vector: int
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
@property
def split_k_serial(self) -> bool: ...
@split_k_serial.setter
def split_k_serial(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def split_k_parallel(self) -> bool: ...
@split_k_parallel.setter
def split_k_parallel(self, val: bool) -> None:
"""
Args:
val:
"""
...
def check_valid(self) -> None: ...
@property
def trans_a(self) -> bool: ...
@trans_a.setter
def trans_a(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def trans_b(self) -> bool: ...
@trans_b.setter
def trans_b(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def trans_c(self) -> bool: ...
@trans_c.setter
def trans_c(self, val: bool) -> None:
"""
Args:
val:
"""
...
def query_workspace_size(self, m: int, n: int, k: int, split_k_slices: int) -> int:
"""
Args:
m:
n:
k:
split_k_slices:
"""
...
def supported(self, m: int, n: int, k: int) -> bool:
"""
Args:
m:
n:
k:
"""
...
def supported_ldx(self, lda: int, ldb: int, ldc: int) -> bool:
"""
Args:
lda:
ldb:
ldc:
"""
...
class GemmParams:
algo_desp: GemmAlgoDesp
split_k_slices: int
workspace: Tensor = Tensor()
a_inds: Tensor = Tensor()
b_inds: Tensor = Tensor()
c_inds: Tensor = Tensor()
alpha: float
beta: float
stream: int
timer: CUDAKernelTimer
def __init__(self, timer: CUDAKernelTimer = CUDAKernelTimer(False)) -> None:
"""
Args:
timer:
"""
...
def check_valid(self) -> None: ...
@property
def a(self) -> Tensor: ...
@a.setter
def a(self, val: Tensor) -> None:
"""
Args:
val:
"""
...
@property
def b(self) -> Tensor: ...
@b.setter
def b(self, val: Tensor) -> None:
"""
Args:
val:
"""
...
@property
def c(self) -> Tensor: ...
@c.setter
def c(self, val: Tensor) -> None:
"""
Args:
val:
"""
...
class GemmMainUnitTest: class GemmMainUnitTest:
@staticmethod @staticmethod
def get_all_algo_desp() -> List[GemmAlgoDesp]: ... def get_all_algo_desp() -> List[Any]: ...
@staticmethod @staticmethod
def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type: str = "NS", a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]: def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type: str = "0", a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]:
""" """
Args: Args:
a_shape: a_shape:
......
...@@ -104,6 +104,8 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -104,6 +104,8 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
self.add_member("map_8_8", "tsl::robin_map<uint64_t, uint64_t>") self.add_member("map_8_8", "tsl::robin_map<uint64_t, uint64_t>")
self.add_pybind_member("insert_count_", "int64_t", prop_name="insert_count", readwrite=False) self.add_pybind_member("insert_count_", "int64_t", prop_name="insert_count", readwrite=False)
self.valid_hash_key_types = [dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64]
@pccm.pybind.mark @pccm.pybind.mark
@pccm.constructor @pccm.constructor
def ctor(self): def ctor(self):
...@@ -163,11 +165,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -163,11 +165,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
""") """)
for v_items in _dispatch_ints(code, [4, 8], "values_data.itemsize()"): for v_items in _dispatch_ints(code, [4, 8], "values_data.itemsize()"):
...@@ -176,10 +176,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -176,10 +176,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data()); V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(table.size(), custream); tv::cuda::Launch launcher(table.size(), custream);
launcher(tv::hash::clear_table_split<table_t>, table); launcher(tv::hash::clear_map_kernel_split<table_t>, table);
""") """)
return code return code
...@@ -201,9 +201,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -201,9 +201,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
int64_t value_after_insert = keys.dim(0) + insert_count_; int64_t value_after_insert = keys.dim(0) + insert_count_;
TV_ASSERT_RT_ERR(value_after_insert < keys_data.dim(0), "inserted count exceed maximum hash size"); TV_ASSERT_RT_ERR(value_after_insert < keys_data.dim(0), "inserted count exceed maximum hash size");
insert_count_ += keys.dim(0); insert_count_ += keys.dim(0);
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}} }}
auto N = keys.dim(0); auto N = keys.dim(0);
TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_);
if (!values.empty()){{ if (!values.empty()){{
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_); TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(keys.dim(0) == values.dim(0), "number of key and value must same"); TV_ASSERT_RT_ERR(keys.dim(0) == values.dim(0), "number of key and value must same");
...@@ -231,10 +231,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -231,10 +231,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data()); const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
...@@ -248,7 +247,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -248,7 +247,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
tv::cuda::Launch launcher(N, custream); tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::insert_split<table_t>, table, key_ptr, value_ptr, size_t(N)); launcher(tv::hash::insert_split<table_t>, table, key_ptr, value_ptr, size_t(N));
...@@ -279,6 +278,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -279,6 +278,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_); TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(N == values.dim(0) && is_empty.dim(0) == N, "number of key and value must same"); TV_ASSERT_RT_ERR(N == values.dim(0) && is_empty.dim(0) == N, "number of key and value must same");
auto is_empty_ptr = is_empty.data_ptr<uint8_t>(); auto is_empty_ptr = is_empty.data_ptr<uint8_t>();
if (!is_cpu){{
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}}
""") """)
with code.if_("is_cpu"): with code.if_("is_cpu"):
map_name = "cpu_map" map_name = "cpu_map"
...@@ -304,10 +306,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -304,10 +306,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data()); K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
...@@ -319,7 +320,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -319,7 +320,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_ptr = reinterpret_cast<V*>(values.raw_data()); V* value_ptr = reinterpret_cast<V*>(values.raw_data());
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
tv::cuda::Launch launcher(N, custream); tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::query_split<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N)); launcher(tv::hash::query_split<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
...@@ -361,11 +362,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -361,11 +362,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(count.device() == 0, "count must be cuda"); TV_ASSERT_RT_ERR(count.device() == 0, "count must be cuda");
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
constexpr K kEmptyKey = std::numeric_limits<K>::max(); using Kunsigned = tv::hash::itemsize_to_unsigned_t<sizeof(K)>;
auto count_ptr = count.data_ptr<K>();
auto count_ptr = count.data_ptr<Kunsigned>();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
""") """)
...@@ -376,10 +378,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -376,10 +378,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data()); V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(table.size(), custream); tv::cuda::Launch launcher(table.size(), custream);
launcher(tv::hash::assign_arange_split<table_t, K>, table, count_ptr); launcher(tv::hash::assign_arange_split<table_t, Kunsigned>, table, count_ptr);
""") """)
else: else:
code.raw(f""" code.raw(f"""
...@@ -426,7 +428,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -426,7 +428,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_); TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_);
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_); TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(N == values.dim(0), "number of key and value must same"); TV_ASSERT_RT_ERR(N == values.dim(0), "number of key and value must same");
if (!is_cpu){{
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}}
""") """)
with code.if_("is_cpu"): with code.if_("is_cpu"):
map_name = "cpu_map" map_name = "cpu_map"
...@@ -450,12 +454,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -450,12 +454,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
auto count_ptr = count.data_ptr<K>(); using Kunsigned = tv::hash::itemsize_to_unsigned_t<sizeof(K)>;
auto count_ptr = count.data_ptr<Kunsigned>();
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data()); K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
...@@ -467,10 +471,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -467,10 +471,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_ptr = reinterpret_cast<V*>(values.raw_data()); V* value_ptr = reinterpret_cast<V*>(values.raw_data());
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
tv::cuda::Launch launcher(N, custream); tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::iterate_table_split<table_t, K>, table, key_ptr, value_ptr, size_t(N), count_ptr); launcher(tv::hash::iterate_table_split<table_t, Kunsigned>, table, key_ptr, value_ptr, size_t(N), count_ptr);
""") """)
else: else:
code.raw(f""" code.raw(f"""
...@@ -523,10 +527,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -523,10 +527,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code.raw(f""" code.raw(f"""
auto custream = reinterpret_cast<cudaStream_t>(stream); auto custream = reinterpret_cast<cudaStream_t>(stream);
""") """)
for k_items in _dispatch_ints(code, [4, 8], "keys_data.itemsize()"): for k_items in _dispatch(code, self.valid_hash_key_types, "keys_data.dtype()"):
code.raw(f""" code.raw(f"""
using K = tv::hash::itemsize_to_unsigned_t<{k_items}>; using K = {k_items};
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data()); K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data()); const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
...@@ -538,7 +541,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin): ...@@ -538,7 +541,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
const V* value_ptr = reinterpret_cast<const V*>(values.raw_data()); const V* value_ptr = reinterpret_cast<const V*>(values.raw_data());
using table_t = using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>, tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>; tv::hash::default_empty_key_v<K>, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0)); table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(N, custream); tv::cuda::Launch launcher(N, custream);
launcher(insert_exist_keys_kernel<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N)); launcher(insert_exist_keys_kernel<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
......
This diff is collapsed.
import pccm
from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib
class ExternalAllocatorGuard(pccm.Class):
def __init__(self):
super().__init__()
self.add_dependency(TensorView)
self.add_member("tensor", "tv::Tensor")
self.add_member("free_func", "std::function<void(tv::Tensor)>")
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("ten", "tv::Tensor")
code.arg("free_func", "std::function<void(tv::Tensor)>")
code.ctor_init("tensor", "ten")
code.ctor_init("free_func", "free_func")
return code
@pccm.constructor
def dctor(self):
code = pccm.code()
return code
@pccm.destructor
def dtor(self):
code = pccm.code()
code.raw(f"""
if (!tensor.empty() && free_func){{
free_func(tensor);
}}
""")
return code
class ExternalAllocator(pccm.Class):
def __init__(self):
super().__init__()
self.add_dependency(TensorView, ExternalAllocatorGuard)
self.use_shared = True
self.ptr_type = "unique"
if self.use_shared:
self.ptr_type = "shared"
self.add_typedef("guard_t", f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def zeros(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def empty(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def full_int(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("value", "int")
code.arg("dtype", "int")
code.arg("device", "int")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def full_float(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("value", "float")
code.arg("dtype", "int")
code.arg("device", "int")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def free(self):
code = pccm.code()
code.arg("ten", "tv::Tensor")
return code
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def free_noexcept(self):
code = pccm.code()
code.arg("ten", "tv::Tensor")
return code
@pccm.member_function
def zeros_guard(self):
code = pccm.code()
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.raw(f"""
// "" means temp memory
auto ten = zeros("", shape, dtype, device);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
""")
return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
@pccm.member_function
def empty_guard(self):
code = pccm.code()
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.raw(f"""
auto ten = empty("", shape, dtype, device);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
""")
return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
@pccm.member_function
def full_int_guard(self):
code = pccm.code()
code.arg("shape", "std::vector<int64_t>")
code.arg("value", "int")
code.arg("dtype", "int")
code.arg("device", "int")
code.raw(f"""
auto ten = full_int("", shape, value, dtype, device);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
""")
return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
@pccm.member_function
def full_float_guard(self):
code = pccm.code()
code.arg("shape", "std::vector<int64_t>")
code.arg("value", "int")
code.arg("dtype", "int")
code.arg("device", "int")
code.raw(f"""
auto ten = full_float("", shape, value, dtype, device);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor t){{
this->free(t);
}});
""")
return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
class ThrustAllocator(pccm.Class):
def __init__(self):
super().__init__()
self.add_dependency(TensorView, ExternalAllocator)
self.add_include("functional", "memory")
self.add_member("allocator_", "ExternalAllocator&",)
self.add_typedef("value_type", "char")
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.ctor_init("allocator_", "allocator")
return code
@pccm.member_function
def allocate(self):
code = pccm.FunctionCode()
code.arg("num_bytes", "std::ptrdiff_t")
code.ret("char*")
code.raw(f"""
auto ten = allocator_.empty("", {{num_bytes}}, tv::uint8, 0);
return reinterpret_cast<char*>(ten.raw_data());
""")
return code
@pccm.member_function
def deallocate(self):
code = pccm.FunctionCode()
code.arg("ptr", "char *")
code.arg("num_bytes", "size_t")
code.raw(f"""
return allocator_.free_noexcept(tv::from_blob(ptr, {{num_bytes}}, tv::uint8, 0));
""")
return code
import pccm
from cumm.gemm.main import GemmMainUnitTest
from cumm.conv.main import ConvMainUnitTest
from .alloc import ExternalAllocator
from spconv.core import ConvAlgo
from spconv.constants import SpconvAllocatorKeys
from cumm.constants import CUMM_CPU_ONLY_BUILD
from cumm.common import GemmBasicHost, TensorView, NlohmannJson
class GemmTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
def __init__(self):
super().__init__()
self.add_dependency(GemmBasicHost, TensorView)
self.add_pybind_member("algo_desp", "tv::gemm::GemmAlgoDesp")
self.add_pybind_member("arch", "std::tuple<int, int>")
self.add_pybind_member("splitk", "int")
@pccm.pybind.mark
@pccm.member_function
def is_valid(self):
code = pccm.code()
code.raw(f"return splitk > 0 && std::get<0>(arch) > 0")
return code
@pccm.pybind.mark
@pccm.constructor
def defaultctor(self):
code = pccm.code()
code.ctor_init("algo_desp", "tv::gemm::GemmAlgoDesp()")
code.ctor_init("arch", "std::make_tuple(-1, -1)")
code.ctor_init("splitk", "-1")
return code
@pccm.pybind.mark
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("algo_desp",
"tv::gemm::GemmAlgoDesp",
pyanno="cumm.tensorview.gemm.GemmAlgoDesp")
code.arg("arch", "std::tuple<int, int>")
code.arg("splitk", "int")
code.ctor_init("algo_desp", "algo_desp")
code.ctor_init("arch", "arch")
code.ctor_init("splitk", "splitk")
return code
class ConvTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
def __init__(self):
super().__init__()
self.add_dependency(GemmBasicHost, TensorView)
self.add_pybind_member("algo_desp", "tv::gemm::ConvAlgoDesp")
self.add_pybind_member("arch", "std::tuple<int, int>")
self.add_pybind_member("splitk", "int")
@pccm.pybind.mark
@pccm.constructor
def defaultctor(self):
code = pccm.code()
code.ctor_init("algo_desp", "tv::gemm::ConvAlgoDesp()")
code.ctor_init("arch", "std::make_tuple(-1, -1)")
code.ctor_init("splitk", "-1")
return code
@pccm.pybind.mark
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("algo_desp",
"tv::gemm::ConvAlgoDesp",
pyanno="cumm.tensorview.gemm.ConvAlgoDesp")
code.arg("arch", "std::tuple<int, int>")
code.arg("splitk", "int")
code.ctor_init("algo_desp", "algo_desp")
code.ctor_init("arch", "arch")
code.ctor_init("splitk", "splitk")
return code
@pccm.pybind.mark
@pccm.member_function
def is_valid(self):
code = pccm.code()
code.raw(f"return splitk > 0 && std::get<0>(arch) > 0")
return code
class GemmTunerSimple(pccm.ParameterizedClass):
def __init__(self, gemm_cu: GemmMainUnitTest, conv_cu: ConvMainUnitTest):
super().__init__()
self.add_dependency(ExternalAllocator, GemmTuneResult,
ConvTuneResult, TensorView)
self.add_param_class("gemm", gemm_cu, "GemmMain")
self.add_param_class("conv", conv_cu, "ConvMain")
self.add_include("tensorview/utility/tuplehash.h")
self.add_member("desps_", "std::vector<tv::gemm::GemmAlgoDesp>")
self.add_member("nvrtc_progs_", "std::unordered_map<std::string, tv::NVRTCProgram>")
self.add_member("nvrtc_caches_", "std::unordered_map<std::tuple<std::string, int, int, std::uintptr_t>, tv::NVRTCModule>")
@pccm.pybind.mark
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("desps", "std::vector<tv::gemm::GemmAlgoDesp>")
code.arg("nvrtc_progs", "std::unordered_map<std::string, std::string>")
code.ctor_init("desps_", "desps")
code.raw(f"""
for (auto& v : nvrtc_progs){{
const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());
nvrtc_progs_.insert(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));
}}
""")
return code
@pccm.member_function
def get_all_available(self):
code = pccm.code()
code.arg("a, b, c", "tv::Tensor")
code.arg("trans_a, trans_b, trans_c", "bool")
code.arg("arch", "std::tuple<int, int>")
code.arg("nvrtc_progs", "std::unordered_map<std::string, std::string>")
code.ctor_init("desps_", "desps")
code.raw(f"""
for (auto& v : nvrtc_progs){{
const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());
nvrtc_progs_.insert(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));
}}
""")
return code
class ConvGemmOps(pccm.ParameterizedClass):
def __init__(self, gemm_cu: GemmMainUnitTest, conv_cu: ConvMainUnitTest):
super().__init__()
self.add_dependency(ExternalAllocator, GemmTuneResult,
ConvTuneResult)
self.add_param_class("gemm", gemm_cu, "GemmMain")
self.add_param_class("conv", conv_cu, "ConvMain")
@pccm.pybind.mark
@pccm.static_function
def indice_conv(self):
"""1. this function need to take a out features
that from subm first mm.
2. this function don't support CPU.
"""
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.arg("out_features_after_mm", "tv::Tensor")
code.arg("features, filters, indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor")
code.arg("num_activate_out", "int")
code.arg("inverse", "bool", "false")
code.arg("subm", "bool", "false")
code.arg("algo", "int", f"{ConvAlgo.Native.value}")
code.arg("filter_hwio", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
throw std::runtime_error("this function can only be used with CUDA.")
""")
return code.ret("tv::Tensor")
code.raw(f"""
TV_ASSERT_RT_ERR(!features.is_cpu(), "this function don't support cpu.")
int out_channel;
if (filter_hwio){{
out_channel = filters.dim(-1);
}}else{{
out_channel = filters.dim(-2);
}}
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
int kv = filters.dim(0);
int kv_center = kv / 2;
tv::Tensor out_features;
if (kv == 1 && subm){{
return;
}}
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
int maxnhot = 0;
bool all_zero = true;
for (int i = 0; i < kv; ++i){{
if (indice_pair_num_cpu_ptr[i] != 0){{
all_zero = false;
maxnhot = std::max(maxnhot, indice_pair_num_cpu_ptr[i]);
}}
}}
if (subm && all_zero){{
return;
}}
bool inited = subm;
auto a = features;
auto c = out_features;
auto pair_in = indice_pairs[int(inverse)];
auto pair_out = indice_pairs[int(!inverse)];
""")
return code
...@@ -23,13 +23,8 @@ class OMPLib(pccm.Class): ...@@ -23,13 +23,8 @@ class OMPLib(pccm.Class):
self.add_dependency(TensorView) self.add_dependency(TensorView)
self.add_include("tensorview/parallel/all.h") self.add_include("tensorview/parallel/all.h")
if compat.InWindows: if compat.InWindows:
self.build_meta.add_cflags("cl", "/openmp") self.build_meta.add_public_cflags("cl", "/openmp")
else: else:
self.build_meta.add_cflags("g++", "-fopenmp") self.build_meta.add_public_cflags("g++", "-fopenmp")
self.build_meta.add_cflags("clang++", "-fopenmp") self.build_meta.add_public_cflags("clang++", "-fopenmp")
if "g++" not in self.build_meta.compiler_to_ldflags: self.build_meta.add_ldflags("g++,clang++", "-fopenmp")
self.build_meta.compiler_to_ldflags["g++"] = []
self.build_meta.compiler_to_ldflags["g++"].extend(["-fopenmp"])
if "clang++" not in self.build_meta.compiler_to_ldflags:
self.build_meta.compiler_to_ldflags["clang++"] = []
self.build_meta.compiler_to_ldflags["clang++"].extend(["-fopenmp"])
This diff is collapsed.
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
from cumm.conv.bases import ConvEnum
from cumm.gemm.core.metaarray import MetaArray, seq from cumm.gemm.core.metaarray import MetaArray, seq
from cumm import dtypes from cumm import dtypes
import pccm import pccm
...@@ -202,14 +201,14 @@ class IndiceMaxPool(pccm.Class): ...@@ -202,14 +201,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed. // if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value; int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}); }});
if (!found){{ if (!found){{
int NumFeatures = 16; int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}} }}
...@@ -244,14 +243,14 @@ class IndiceMaxPool(pccm.Class): ...@@ -244,14 +243,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed. // if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value; int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}); }});
if (!found){{ if (!found){{
int NumFeatures = 16; int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}} }}
...@@ -287,14 +286,14 @@ class IndiceMaxPool(pccm.Class): ...@@ -287,14 +286,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed. // if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value; int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}); }});
if (!found){{ if (!found){{
int NumFeatures = 16; int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}} }}
...@@ -331,14 +330,14 @@ class IndiceMaxPool(pccm.Class): ...@@ -331,14 +330,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed. // if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value; int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}); }});
if (!found){{ if (!found){{
int NumFeatures = 16; int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures; int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0)); dim3 blocks(tv::div_up(out.dim(1), int64_t(NumFeatures)), tv::div_up(nhot, int64_t(Num0)));
dim3 threads(NumFeatures, Num0); dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream); launcher = tv::cuda::Launch(blocks, threads, cudastream);
}} }}
......
...@@ -126,6 +126,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -126,6 +126,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
super().__init__() super().__init__()
self.add_dependency(TensorView, TensorViewHashKernel) self.add_dependency(TensorView, TensorViewHashKernel)
self.add_param_class("layout_ns", layout, "Layout") self.add_param_class("layout_ns", layout, "Layout")
self.add_include("tensorview/hash/ops.h")
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.zyx = zyx self.zyx = zyx
...@@ -447,7 +448,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -447,7 +448,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(point_indice_data.dim(0) >= points.dim(0), "point_indice_data too small") TV_ASSERT_RT_ERR(point_indice_data.dim(0) >= points.dim(0), "point_indice_data too small")
num_per_voxel.zero_(ctx); num_per_voxel.zero_(ctx);
table_t hash = table_t(hashdata.data_ptr<pair_t>(), expected_hash_data_num); table_t hash = table_t(hashdata.data_ptr<pair_t>(), expected_hash_data_num);
hash.clear(custream); tv::hash::clear_map(hash, custream);
auto launcher = tv::cuda::Launch(points.dim(0), custream); auto launcher = tv::cuda::Launch(points.dim(0), custream);
launcher(kernel::build_hash_table<table_t>, hash, points.data_ptr<const {self.dtype}>(), launcher(kernel::build_hash_table<table_t>, hash, points.data_ptr<const {self.dtype}>(),
point_indice_data.data_ptr<int64_t>(), point_indice_data.data_ptr<int64_t>(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment