Commit 899008fa authored by yan.yan's avatar yan.yan
Browse files

working on c++ only

parent f78575ea
<!--
Copyright 2021 Yan Yan
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# How to develop spconv 2.x
## First step
spconv 2.x is written in a unique c++ framework ```pccm```. read [pccm guide]() to learn how to use ```pccm```.
It's recommend to uninstall spconv and cumm installed by pip, then install spconv and cumm both in editable mode (```pip install -e .```)
## Architecture
\ No newline at end of file
......@@ -159,6 +159,9 @@ if disable_jit is not None and disable_jit == "1":
from spconv.csrc.utils import BoxOps
from spconv.csrc.hash.core import HashTable
from cumm.common import CompileInfo
from spconv.csrc.sparse.alloc import ExternalAllocator
from spconv.csrc.sparse.convops import GemmTunerSimple, ExternalSpconvMatmul
from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS)
convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS)
......@@ -172,14 +175,30 @@ if disable_jit is not None and disable_jit == "1":
std = "c++14"
else:
std = "c++17"
cus = [cu, convcu, SpconvOps(), BoxOps(), HashTable(), CompileInfo()]
if CUMM_CPU_ONLY_BUILD:
cus = [SpconvOps(), BoxOps(), HashTable(), CompileInfo()]
gemmtuner = GemmTunerSimple(cu)
gemmtuner.namespace = "csrc.sparse.convops.gemmops"
convtuner = ConvTunerSimple(convcu)
convtuner.namespace = "csrc.sparse.convops.convops"
convops = ConvGemmOps(gemmtuner, convtuner)
convops.namespace = "csrc.sparse.convops.spops"
else:
gemmtuner = GemmTunerSimple(None)
gemmtuner.namespace = "csrc.sparse.convops.gemmops"
convtuner = ConvTunerSimple(None)
convtuner.namespace = "csrc.sparse.convops.convops"
convops = ConvGemmOps(gemmtuner, convtuner)
convops.namespace = "csrc.sparse.convops.spops"
cus = [gemmtuner, convtuner,
convops, SpconvOps(), BoxOps(), HashTable(), CompileInfo(),
ExternalAllocator(),
ExternalSpconvMatmul()]
if not CUMM_CPU_ONLY_BUILD:
cus.extend([cu, convcu])
ext_modules: List[Extension] = [
PCCMExtension(cus,
"spconv/core_cc",
Path(__file__).resolve().parent / "spconv",
objects_folder="objects",
std=std,
disable_pch=True,
verbose=True)
......
......@@ -37,7 +37,7 @@ from cumm import dtypes
from spconv.constants import (NDIM_DONT_CARE, SPCONV_BWD_SPLITK,
SPCONV_NVRTC_MODE, SPCONV_DEBUG_NVRTC_KERNELS)
from spconv.core import ALL_IMPGEMM_PARAMS, AlgoHint, ConvAlgo
from spconv.core import ALL_IMPGEMM_PARAMS, AlgoHint, ConvAlgo, ALL_NATIVE_PARAMS
from spconv.core_cc.cumm.conv.main import ConvMainUnitTest
from spconv.core_cc.cumm.gemm.main import GemmMainUnitTest
from spconv.cppconstants import COMPILED_CUDA_ARCHS
......@@ -49,14 +49,17 @@ from spconv import algocore
from cumm.conv.main import gen_gemm_kernels as gen_conv_kernels
from cumm.gemm.main import gen_gemm_kernels
from spconv.core_cc.csrc.sparse.convops import GemmTuneResult, ConvTuneResult
from spconv.core_cc.csrc.sparse.convops.gemmops import GemmTunerSimple as GemmTunerSimpleBase
from spconv.core_cc.csrc.sparse.convops.convops import ConvTunerSimple as ConvTunerSimpleBase
ALL_ALGO_DESPS = GemmMainUnitTest.get_all_algo_desp()
ALL_CONV_ALGO_DESPS = ConvMainUnitTest.get_all_conv_algo_desp()
_GEMM_STATIC_KEY = Tuple[bool, bool, bool, int, int, int, str, str]
_GEMM_STATIC_KEY = Tuple[bool, bool, bool, int, int, int, int, str]
class SimpleGemmAlgoMeta:
def __init__(self, tile_ms: List[int], tile_ns: List[int],
tile_ks: List[int],
tile_shape_to_algos: Dict[int, List[int]]) -> None:
......@@ -67,19 +70,29 @@ class SimpleGemmAlgoMeta:
class BestAlgoByProfile:
def __init__(self, algo_desp: GemmAlgoDesp, arch: Tuple[int, int], splitk: int = 1) -> None:
def __init__(self,
algo_desp: GemmAlgoDesp,
arch: Tuple[int, int],
splitk: int = 1) -> None:
self.algo_desp = algo_desp
self.splitk = splitk
self.arch = arch
class BestConvAlgoByProfile:
def __init__(self, algo_desp: ConvAlgoDesp, arch: Tuple[int, int], splitk: int = 1) -> None:
def __init__(self,
algo_desp: ConvAlgoDesp,
arch: Tuple[int, int],
splitk: int = 1) -> None:
self.algo_desp = algo_desp
self.splitk = splitk
self.arch = arch
def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel], kernel_name: str):
def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel],
kernel_name: str):
nvrtc_mode = SPCONV_NVRTC_MODE
nvrtc_params = tv.gemm.NVRTCParams()
nvrtc_params.cumodule = mod.get_cpp_object()
......@@ -89,8 +102,7 @@ def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel],
ns = ker.namespace
if nvrtc_mode == NVRTCMode.DynamicParallism:
nvrtc_params.kernel_name = mod.get_lowered_name(
f"{ns}::nvrtc_kernel")
nvrtc_params.kernel_name = mod.get_lowered_name(f"{ns}::nvrtc_kernel")
elif nvrtc_mode == NVRTCMode.KernelAndCPU:
nvrtc_params.kernel_name = mod.get_lowered_name(f"{ns}::{kernel_name}")
......@@ -101,8 +113,10 @@ def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel],
nvrtc_params.param_storage = tv.empty([nvrtc_params.param_size],
tv.uint8, 0)
nvrtc_params.param_storage_cpu = tv.empty(
[nvrtc_params.param_size], tv.uint8, -1, pinned=True)
nvrtc_params.param_storage_cpu = tv.empty([nvrtc_params.param_size],
tv.uint8,
-1,
pinned=True)
elif nvrtc_mode == NVRTCMode.Direct:
nvrtc_params.kernel_name = mod.get_lowered_name(f"{ns}::{kernel_name}")
......@@ -120,9 +134,84 @@ def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel],
raise NotImplementedError
return nvrtc_params
class GemmTunerSimple(GemmTunerSimpleBase):
def __init__(self, desps: List[GemmAlgoDesp]) -> None:
super().__init__(desps)
self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int], int], NVRTCParams] = {}
def _compile_nvrtc_module(self, desp: GemmAlgoDesp):
params = algocore.get_gemm_param_from_desp(desp)
kernel = gen_gemm_kernels(params, SPCONV_NVRTC_MODE)
kernel.namespace = "spconv"
custom_names = []
if SPCONV_NVRTC_MODE == NVRTCMode.ConstantMemory:
custom_names = [
f"&{kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"
]
cudadevrt = ""
if SPCONV_NVRTC_MODE == NVRTCMode.DynamicParallism:
cudadevrt_p = get_cudadevrt_path()
assert cudadevrt_p is not None, "DynamicParallism must have cudadevrt"
cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt,
custom_names=custom_names)
mod.load()
return mod, kernel
def cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
key = (str(desp), arch, stream_int)
if key in self._nvrtc_caches:
return self._nvrtc_caches[key]
mod, ker = self._compile_nvrtc_module(desp)
nvrtc_params = _get_nvrtc_params(mod, ker, "gemm_kernel")
self._nvrtc_caches[key] = nvrtc_params
return nvrtc_params
class ConvTunerSimple(ConvTunerSimpleBase):
def __init__(self, desps: List[ConvAlgoDesp]) -> None:
super().__init__(desps)
self._nvrtc_caches: Dict[Tuple[str, Tuple[int, int], int], NVRTCParams] = {}
def _compile_nvrtc_module(self, desp: ConvAlgoDesp):
params = algocore.get_conv_param_from_desp(desp)
kernel = gen_conv_kernels(params, SPCONV_NVRTC_MODE)
kernel.namespace = "spconv"
custom_names = []
if SPCONV_NVRTC_MODE == NVRTCMode.ConstantMemory:
custom_names = [
f"&{kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"
]
cudadevrt = ""
if SPCONV_NVRTC_MODE == NVRTCMode.DynamicParallism:
cudadevrt_p = get_cudadevrt_path()
assert cudadevrt_p is not None, "DynamicParallism must have cudadevrt"
cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt,
verbose=False,
custom_names=custom_names)
mod.load()
return mod, kernel
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
key = (str(desp), arch, stream_int)
if key in self._nvrtc_caches:
return self._nvrtc_caches[key]
mod, ker = self._compile_nvrtc_module(desp)
print(f"Can't find algo {desp} in prebuilt. compile with nvrtc...")
nvrtc_params = _get_nvrtc_params(mod, ker, "conv_kernel")
self._nvrtc_caches[key] = nvrtc_params
return nvrtc_params
class SimpleGemm:
def __init__(self, prebuilt_desps: List[GemmAlgoDesp]) -> None:
all_desps = [algocore.get_conv_algo_desp_from_param(p) for p in ALL_IMPGEMM_PARAMS]
all_desps = [
algocore.get_gemm_algo_desp_from_param(p)
for p in ALL_NATIVE_PARAMS
]
self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
if SPCONV_DEBUG_NVRTC_KERNELS:
......@@ -178,7 +267,9 @@ class SimpleGemm:
kernel.namespace = "spconv"
custom_names = []
if SPCONV_NVRTC_MODE == NVRTCMode.ConstantMemory:
custom_names = [f"&{kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"]
custom_names = [
f"&{kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"
]
cudadevrt = ""
if SPCONV_NVRTC_MODE == NVRTCMode.DynamicParallism:
cudadevrt_p = get_cudadevrt_path()
......@@ -186,12 +277,12 @@ class SimpleGemm:
cudadevrt = str(cudadevrt_p)
mod = CummNVRTCModule([kernel],
cudadevrt_path=cudadevrt,
verbose=False,
custom_names=custom_names)
mod.load()
return mod, kernel
def _cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int, int]):
def _cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int,
int]):
key = (str(desp), arch)
if key in self._nvrtc_caches:
return self._nvrtc_caches[key]
......@@ -218,12 +309,15 @@ class SimpleGemm:
trans_c = False
avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[GemmAlgoDesp] = []
# print(self.static_key_to_desps)
for algo in avail_algos:
static_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype,
shuffle_type.value, algo)
# print(static_key)
desps = self.static_key_to_desps.get(static_key, None)
if desps is None or len(desps) == 0:
continue
# print(desps)
for desp in desps:
# skip volta tensor op since it is very slow in architectures except volta.
if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
......@@ -430,6 +524,7 @@ class SimpleGemm:
best_scatter_params = (-1, -1, -1, -1)
all_profile_res: List[BestAlgoByProfile] = []
# print(avail)
for desp in avail:
c_.zero_whole_storage_()
split_k_slices = 1
......@@ -466,7 +561,8 @@ class SimpleGemm:
times.append(np.mean(this_times[1:]))
spk_speeds.append(times[-1])
all_profile_res.append(BestAlgoByProfile(desp, arch, splitk=spk))
all_profile_res.append(
BestAlgoByProfile(desp, arch, splitk=spk))
min_time = 1000
min_idx = -1
......@@ -490,8 +586,7 @@ class SimpleGemm:
return res, min_time
def run_with_tuned_result(
self,
def run_with_tuned_result(self,
profile_res: BestAlgoByProfile,
a: tv.Tensor,
b: tv.Tensor,
......@@ -501,7 +596,7 @@ class SimpleGemm:
trans_c: bool,
arch: Tuple[int, int],
stream: int,
shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
shuffle_type: ShuffleStrideType,
a_inds: tv.Tensor = tv.Tensor(),
b_inds: tv.Tensor = tv.Tensor(),
c_inds: tv.Tensor = tv.Tensor(),
......@@ -510,7 +605,8 @@ class SimpleGemm:
beta: float = 0.0,
gather_data: tv.Tensor = tv.Tensor(),
workspace: tv.Tensor = tv.Tensor(),
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
timer: CUDAKernelTimer = CUDAKernelTimer(False),
force_nvrtc: bool = False):
m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
trans_b, trans_c,
shuffle_type.value,
......@@ -526,8 +622,10 @@ class SimpleGemm:
if profile_res.splitk > 1:
split_k_slices = profile_res.splitk
params = GemmParams()
if algo_desp.is_nvrtc and str(algo_desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(algo_desp, profile_res.arch)
is_not_static = str(algo_desp) not in self.prebuilt_desp_names
if algo_desp.is_nvrtc and (is_not_static or force_nvrtc):
params.nvrtc_params = self._cached_get_nvrtc_params(
algo_desp, profile_res.arch)
params.a = a
params.b = b
......@@ -569,8 +667,12 @@ _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int]
class SimpleConv:
def __init__(self, prebuilt_desps: List[ConvAlgoDesp]) -> None:
all_desps = [algocore.get_conv_algo_desp_from_param(p) for p in ALL_IMPGEMM_PARAMS]
all_desps = [
algocore.get_conv_algo_desp_from_param(p)
for p in ALL_IMPGEMM_PARAMS
]
self.prebuilt_desps = prebuilt_desps
self.prebuilt_desp_names = {str(d) for d in prebuilt_desps}
self.prebuilt_desp_names.clear()
......@@ -650,6 +752,7 @@ class SimpleConv:
use_f32_as_accum = weight.dim(0) * kv > 128 * 27
else:
use_f32_as_accum = fp32_accum
use_f32_as_accum = False
for algo in avail_algos:
static_key = (layout_i.layout_type.value,
layout_w.layout_type.value,
......@@ -664,7 +767,6 @@ class SimpleConv:
if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
continue
if arch >= (7, 0) and is_fp16:
# skip simt fp16 kernels if we have tensor core
if desp.algo == GemmAlgo.Simt:
continue
if use_f32_as_accum:
......@@ -675,6 +777,7 @@ class SimpleConv:
ldw = weight.dim(-1)
ldo = out.dim(-1)
mask_width_valid = True
if desp.op_type == ConvOpType.kBackwardWeight.value:
assert mask_width > 0
mask_width_valid = mask_width % desp.tile_shape[2] == 0
......@@ -722,7 +825,9 @@ class SimpleConv:
kernel.namespace = "spconv"
custom_names = []
if SPCONV_NVRTC_MODE == NVRTCMode.ConstantMemory:
custom_names = [f"&{kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"]
custom_names = [
f"&{kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"
]
cudadevrt = ""
if SPCONV_NVRTC_MODE == NVRTCMode.DynamicParallism:
cudadevrt_p = get_cudadevrt_path()
......@@ -735,10 +840,12 @@ class SimpleConv:
mod.load()
return mod, kernel
def _cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int]):
def _cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int,
int]):
key = (str(desp), arch)
if key in self._nvrtc_caches:
return self._nvrtc_caches[key]
print(f"Can't find algo {desp} in prebuilt. compile with nvrtc...")
mod, ker = self._compile_nvrtc_module(desp)
nvrtc_params = _get_nvrtc_params(mod, ker, "conv_kernel")
self._nvrtc_caches[key] = nvrtc_params
......@@ -795,8 +902,8 @@ class SimpleConv:
params.indices = indices
params.mask = mask
params.mask_output = mask_output
if op_type == ConvOpType.kBackwardWeight:
assert not mask_output.empty()
# if op_type == ConvOpType.kBackwardWeight:
# assert not mask_output.empty()
if op_type == ConvOpType.kBackwardInput:
params.reverse_mask = reverse_mask
params.mask_filter = mask_filter
......@@ -808,20 +915,20 @@ class SimpleConv:
spk_speeds = []
for spk in splitk_tests:
this_times = []
for j in range(3):
GemmMainUnitTest.stream_synchronize(stream)
t = time.time()
for j in range(4):
params.split_k_slices = spk
if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names:
with tv.measure_duration(stream=stream) as measure:
if desp.is_nvrtc and str(
desp) not in self.prebuilt_desp_names:
tv.gemm.run_nvrtc_conv_kernel(params)
else:
ConvMainUnitTest.implicit_gemm2(params)
GemmMainUnitTest.stream_synchronize(stream)
this_times.append(time.time() - t)
this_times.append(measure.duration)
times.append(np.mean(this_times[1:]))
spk_speeds.append(times[-1])
all_profile_res.append(BestConvAlgoByProfile(desp, arch, splitk=spk))
all_profile_res.append(
BestConvAlgoByProfile(desp, arch, splitk=spk))
if not all_profile_res:
raise ValueError("can't find suitable algorithm for", op_type)
min_time = 1000
......@@ -865,7 +972,8 @@ class SimpleConv:
stream: int = 0,
workspace: tv.Tensor = tv.Tensor(),
verbose: bool = False,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
timer: CUDAKernelTimer = CUDAKernelTimer(False),
force_nvrtc: bool = False):
channel_k = output.dim(1)
channel_c = inp.dim(1)
# GemmMainUnitTest.stream_synchronize(stream)
......@@ -879,13 +987,17 @@ class SimpleConv:
else:
op_type_value = op_type.value
params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type_value))
if algo_desp.is_nvrtc and str(algo_desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(algo_desp, profile_res.arch)
is_not_static = str(
algo_desp) not in self.prebuilt_desp_names
if algo_desp.is_nvrtc and (is_not_static or force_nvrtc):
params.nvrtc_params = self._cached_get_nvrtc_params(
algo_desp, profile_res.arch)
params.conv_algo_desp = profile_res.algo_desp
params.input = inp
params.verbose = verbose
params.weight = weight.view([channel_k, -1, channel_c])
params.output = output
params.split_k_slices = split_k_slices
params.alpha = alpha
params.beta = beta
......@@ -893,6 +1005,7 @@ class SimpleConv:
params.mask_argsort = mask_argsort
params.indices = indices
params.mask = mask
params.mask_filter = mask_filter
params.mask_width = mask_width
params.mask_filter = mask_filter
......@@ -919,6 +1032,13 @@ class SimpleConv:
GEMM = SimpleGemm(ALL_ALGO_DESPS)
CONV = SimpleConv(ALL_CONV_ALGO_DESPS)
GEMM_CPP = GemmTunerSimple([
algocore.get_gemm_algo_desp_from_param(p)
for p in ALL_NATIVE_PARAMS])
CONV_CPP = ConvTunerSimple([
algocore.get_conv_algo_desp_from_param(p)
for p in ALL_IMPGEMM_PARAMS])
if __name__ == "__main__":
print(len(ALL_CONV_ALGO_DESPS))
print(ALL_CONV_ALGO_DESPS[0])
......@@ -24,8 +24,8 @@ from cumm.tensorview.gemm import ConvLayoutType as ConvLayoutTypeCpp
from cumm.tensorview.gemm import ShuffleStrideType as ShuffleStrideTypeCpp
from cumm.tensorview.gemm import ConvParams, GemmAlgoDesp, GemmParams
from cumm.gemm.main import GemmAlgoParams
from cumm.conv.main import ConvAlgoParams, ConvIterAlgo
from cumm.gemm.main import GemmAlgoParams, gen_gemm_kernels
from cumm.conv.main import ConvAlgoParams, ConvIterAlgo, gen_gemm_kernels as gen_conv_kernels
from cumm import dtypes
from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout,
ConvLayoutType, ConvMode, ConvOpType)
......@@ -56,10 +56,15 @@ def _assign_gemm_desp_props(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
desp.access_per_vector = p.access_per_vector
desp.is_nvrtc = p.is_nvrtc
def get_gemm_algo_desp_from_param(p: GemmAlgoParams):
desp = GemmAlgoDesp()
_assign_gemm_desp_props(desp, p)
# here we must generate kernel for element-per-access data
ker = gen_gemm_kernels(p)
desp.element_per_access_a = ker.input_spec.input_iter_a.element_per_acc
desp.element_per_access_b = ker.input_spec.input_iter_b.element_per_acc
desp.element_per_access_c = ker.output_spec.out_iter.element_per_acc
return desp
......@@ -78,6 +83,10 @@ def get_conv_algo_desp_from_param(p: ConvAlgoParams):
desp.interleave_o = p.layout_desp_output.interleave
desp.mask_sparse = p.mask_sparse
desp.increment_k_first = p.increment_k_first
ker = gen_conv_kernels(p)
desp.element_per_access_a = ker.input_spec.input_iter_a.element_per_acc
desp.element_per_access_b = ker.input_spec.input_iter_b.element_per_acc
desp.element_per_access_c = ker.output_spec.out_iter.element_per_acc
return desp
......@@ -106,6 +115,7 @@ def _assign_gemm_params(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
p.is_nvrtc = desp.is_nvrtc
def get_gemm_param_from_desp(desp: GemmAlgoDesp):
p = GemmAlgoParams((0, 0, 0), (0, 0, 0), 0, "s8,s8,s8,s8,s8", False, False,
False, GemmAlgo.Simt)
......
"""Benchmark MinkowskiEngine
"""
from spconv.benchmark.core import get_voxel_data
import time
from pathlib import Path
import numpy as np
import torch
from torch import nn
from spconv.core import ConvAlgo
from cumm import dtypes
from spconv.test_utils import params_grid
_DTYPE_TO_TORCH_DTYPE = {
dtypes.float32: torch.float32,
dtypes.float16: torch.float16,
}
def bench_me_basic(dtype_str: str):
dtype = dtypes.get_dtype_by_shortcut(dtype_str)
if dtype not in _DTYPE_TO_TORCH_DTYPE:
raise NotImplementedError("only support bench f32 and f16 for now")
torch_dtype = _DTYPE_TO_TORCH_DTYPE[dtype]
"""Benchmark torchsparse
"""
from spconv.benchmark.core import get_voxel_data
import time
from pathlib import Path
import numpy as np
import torch
from torch import nn
from spconv.core import ConvAlgo
from cumm import dtypes
from spconv.test_utils import params_grid
_DTYPE_TO_TORCH_DTYPE = {
dtypes.float32: torch.float32,
dtypes.float16: torch.float16,
}
def bench_torchsparse_basic(dtype_str: str):
dtype = dtypes.get_dtype_by_shortcut(dtype_str)
if dtype not in _DTYPE_TO_TORCH_DTYPE:
raise NotImplementedError("only support bench f32 and f16 for now")
torch_dtype = _DTYPE_TO_TORCH_DTYPE[dtype]
......@@ -13,6 +13,7 @@
# limitations under the License.
from pathlib import Path
from typing import List
import pccm
from pccm.utils import project_is_editable, project_is_installed
......@@ -32,6 +33,10 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from spconv.csrc.sparse.alloc import ExternalAllocator
from spconv.csrc.utils import BoxOps
from spconv.csrc.hash.core import HashTable
from spconv.csrc.sparse.convops import GemmTunerSimple, ExternalSpconvMatmul
from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps
from spconv.csrc.sparse.convops import SimpleExternalSpconvMatmul
all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS
all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle))
cu = GemmMainUnitTest(all_shuffle)
......@@ -41,8 +46,35 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp))
convcu = ConvMainUnitTest(all_imp)
convcu.namespace = "cumm.conv.main"
pccm.builder.build_pybind([cu, convcu, SpconvOps(), BoxOps(), HashTable(), CompileInfo(), ExternalAllocator()],
gemmtuner = GemmTunerSimple(cu)
gemmtuner.namespace = "csrc.sparse.convops.gemmops"
convtuner = ConvTunerSimple(convcu)
convtuner.namespace = "csrc.sparse.convops.convops"
convops = ConvGemmOps(gemmtuner, convtuner)
convops.namespace = "csrc.sparse.convops.spops"
cus = [
cu, convcu, gemmtuner, convtuner,
convops,
SpconvOps(),
BoxOps(),
HashTable(),
CompileInfo(),
ExternalAllocator(),
ExternalSpconvMatmul(),
SimpleExternalSpconvMatmul(),
]
pccm.builder.build_pybind(cus,
PACKAGE_ROOT / "core_cc",
namespace_root=PACKAGE_ROOT,
load_library=False)
load_library=False,
verbose=True)
# cus_dev: List[pccm.Class] = [
# ]
# pccm.builder.build_pybind(cus_dev,
# PACKAGE_ROOT / "core_cc_dev",
# namespace_root=PACKAGE_ROOT,
# load_library=False,
# verbose=True)
......@@ -30,6 +30,7 @@ if _filter_hwio_env is not None:
raise NotImplementedError("SPCONV_FILTER_HWIO is deprecated. use SPCONV_SAVED_WEIGHT_LAYOUT instead.")
DISABLE_JIT = os.getenv("SPCONV_DISABLE_JIT", "0") == "1"
NDIM_DONT_CARE = 3
FILTER_HWIO = False
......@@ -59,8 +60,10 @@ SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,
SPCONV_NVRTC_MODE = NVRTCMode.ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS = False
SPCONV_DEBUG_CPP_ONLY = project_is_editable(PACKAGE_NAME)
class SpconvAllocatorKeys:
class AllocKeys:
Pair = "Pair"
IndiceNumPerLoc = "IndiceNumPerLoc"
PairMask = "PairMask"
......@@ -72,5 +75,31 @@ class SpconvAllocatorKeys:
# MaskArgSortFwd = "MaskArgSortFwd"
MaskArgSortBwd = "MaskArgSortBwd"
MaskOutputFwd = "MaskOutputFwd"
OutFeatures = "OutFeatures"
Features = "Features"
Filters = "Filters"
OutBp = "OutBp"
DIn = "DIn"
DFilters = "DFilters"
InpBuffer = "InpBuffer"
OutBuffer = "OutBuffer"
IndicePairsUniq = "IndicePairsUniq"
IndicePairsUniqBackup = "IndicePairsUniqBackup"
HashKOrKV = "HashKOrKV"
HashV = "HashV"
ThrustTemp = "ThrustTemp"
SPCONV_DEBUG_WEIGHT = False
SPCONV_CPP_INDICE_PAIRS = True
SPCONV_CPP_INDICE_PAIRS_IGEMM = True
SPCONV_CPP_GEMM = True
\ No newline at end of file
......@@ -16,9 +16,10 @@ from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgo
from cumm.gemm import kernel
from typing import List
from cumm.gemm.algospec.core import TensorOp
from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvIterAlgo, GemmAlgo
from cumm.conv.main import gen_gemm_params as gen_conv_params, ConvFwdAndBwdInput, ConvBwdWeight, ConvFwd, ConvIterAlgo, GemmAlgo
from cumm.conv.bases import (NCHW, NHWC, ConvIterAlgo, ConvLayout,
ConvLayoutType, ConvMode, ConvOpType)
from spconv.algocore import get_gemm_algo_desp_from_param
from spconv.constants import NDIM_DONT_CARE
......@@ -402,32 +403,6 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first=True,
access_per_vector=1),
]
IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 32, 16), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Simt,
None,
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
*gen_conv_params(ConvBwdWeight, (64, 32, 16), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Simt,
None,
mask_sparse=True,
increment_k_first=True,
access_per_vector=1),
]
IMPLGEMM_VOLTA_PARAMS = [
......@@ -693,6 +668,181 @@ IMPLGEMM_TURING_PARAMS = [
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, )
# all int8 kernels use nvrtc.
*gen_conv_params(ConvFwd, (32, 32, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (32, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (32, 32, 64), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (32, 64, 64), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 64, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 32, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (128, 128, 64), (64, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
*gen_conv_params(ConvFwd, (64, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["s8,s8,s8,s32,s32"],
NHWC,
NHWC,
NHWC,
GemmAlgo.Turing,
TensorOp((8, 8, 16)),
mask_sparse=True,
increment_k_first=True,
access_per_vector=1,
is_nvrtc=False),
# *gen_conv_params(ConvFwd, (32, 32, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 64, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 32, 64), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
]
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS
......
......@@ -48,7 +48,7 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
def generate_conv_inds_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0, use_bound_algo: bool = False) -> int:
"""
Args:
indices:
......@@ -58,6 +58,7 @@ class SpconvOps:
indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds:
indice_num_per_loc:
num_out_act:
batch_size:
output_dims:
......@@ -68,6 +69,7 @@ class SpconvOps:
dilation:
transposed:
stream_int:
use_bound_algo:
"""
...
@staticmethod
......@@ -191,6 +193,31 @@ class SpconvOps:
"""
...
@staticmethod
def indice_maxpool(out_features: Tensor, features: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, num_activate_out: int, stream: int = 0) -> None:
"""
Args:
out_features:
features:
indice_pairs:
indice_pair_num:
num_activate_out:
stream:
"""
...
@staticmethod
def indice_maxpool_backward(din: Tensor, features: Tensor, out_features: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, stream: int = 0) -> None:
"""
Args:
din:
features:
out_features:
out_bp:
indice_pairs:
indice_pair_num:
stream:
"""
...
@staticmethod
def maxpool_implicit_gemm_forward(out: Tensor, inp: Tensor, inds: Tensor, stream: int = 0) -> None:
"""
Args:
......@@ -369,7 +396,18 @@ class SpconvOps:
@staticmethod
def get_int32_max() -> int: ...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0) -> Tensor:
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> int:
"""
Args:
kv:
num_act_in:
num_act_out_bound:
subm:
use_int64_hash_k:
"""
...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> Tuple[Tensor, int]:
"""
Args:
allocator:
......@@ -386,10 +424,11 @@ class SpconvOps:
transposed:
is_train:
stream_int:
num_out_act_bound:
"""
...
@staticmethod
def get_indice_pairs(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, stream_int: int = 0) -> None:
def get_indice_pairs(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> int:
"""
Args:
allocator:
......@@ -405,12 +444,6 @@ class SpconvOps:
subm:
transposed:
stream_int:
"""
...
@staticmethod
def test_allocator(allocator) -> None:
"""
Args:
allocator:
num_out_act_bound:
"""
...
......@@ -2,25 +2,29 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
def zeros(self, name: str, shape: List[int], dtype: int, device: int, is_temp_memory: bool = False, stream: int = 0) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
is_temp_memory:
stream:
"""
...
def empty(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
def empty(self, name: str, shape: List[int], dtype: int, device: int, is_temp_memory: bool = False, stream: int = 0) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
is_temp_memory:
stream:
"""
...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int) -> Tensor:
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int, is_temp_memory: bool = False, stream: int = 0) -> Tensor:
"""
Args:
name:
......@@ -28,9 +32,11 @@ class ExternalAllocator:
value:
dtype:
device:
is_temp_memory:
stream:
"""
...
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int) -> Tensor:
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int, is_temp_memory: bool = False, stream: int = 0) -> Tensor:
"""
Args:
name:
......@@ -38,6 +44,14 @@ class ExternalAllocator:
value:
dtype:
device:
is_temp_memory:
stream:
"""
...
def get_tensor_by_name(self, name: str) -> Tensor:
"""
Args:
name:
"""
...
def free(self, ten: Tensor) -> None:
......
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview import Tensor
from ...csrc.sparse.convops import ExternalSpconvMatmul
class GemmTuneResult:
algo_desp: GemmAlgoDesp
arch: Tuple[int, int]
splitk: int
def is_valid(self) -> bool: ...
@overload
def __init__(self) -> None: ...
@overload
def __init__(self, algo_desp: GemmAlgoDesp, arch: Tuple[int, int], splitk: int) -> None:
"""
Args:
algo_desp:
arch:
splitk:
"""
...
class ConvTuneResult:
algo_desp: ConvAlgoDesp
arch: Tuple[int, int]
splitk: int
@overload
def __init__(self) -> None: ...
@overload
def __init__(self, algo_desp: ConvAlgoDesp, arch: Tuple[int, int], splitk: int) -> None:
"""
Args:
algo_desp:
arch:
splitk:
"""
...
def is_valid(self) -> bool: ...
class ExternalSpconvMatmul:
def indice_conv_init_gemm(self, features_n: str, filters_n: str, all_weight_is_krsc: bool, is_kc_not_ck: bool, kv_center: int, out_channel: int, stream_int: int = 0) -> Tensor:
"""
Args:
features_n:
filters_n:
all_weight_is_krsc:
is_kc_not_ck:
kv_center:
out_channel:
stream_int:
"""
...
def indice_conv_cpu_gemm(self, inp_buffer_n: str, out_buffer_n: str, filters_n: str, all_weight_is_krsc: bool, is_kc_not_ck: bool, nhot: int, index: int) -> None:
"""
Args:
inp_buffer_n:
out_buffer_n:
filters_n:
all_weight_is_krsc:
is_kc_not_ck:
nhot:
index:
"""
...
def indice_conv_bwd_init_gemm(self, features_n: str, filters_n: str, out_bp_n: str, dfilters_n: str, all_weight_is_krsc: bool, is_kc_not_ck: bool, kv_center: int, stream_int: int = 0) -> Tensor:
"""
Args:
features_n:
filters_n:
out_bp_n:
dfilters_n:
all_weight_is_krsc:
is_kc_not_ck:
kv_center:
stream_int:
"""
...
def indice_conv_bwd_cpu_gemm(self, inp_buffer_n: str, out_buffer_n: str, filters_n: str, dfilters_n: str, all_weight_is_krsc: bool, is_kc_not_ck: bool, nhot: int, index: int) -> None:
"""
Args:
inp_buffer_n:
out_buffer_n:
filters_n:
dfilters_n:
all_weight_is_krsc:
is_kc_not_ck:
nhot:
index:
"""
...
class SimpleExternalSpconvMatmul(ExternalSpconvMatmul):
def __init__(self, alloc) -> None:
"""
Args:
alloc:
"""
...
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview import Tensor
from cumm.tensorview.gemm import NVRTCParams
from spconv.core_cc.csrc.sparse.convops import ConvTuneResult
from cumm.tensorview import CUDAKernelTimer
class ConvTunerSimple:
def __init__(self, desps: List[ConvAlgoDesp]) -> None:
"""
Args:
desps:
"""
...
@staticmethod
def get_available_algo_str_from_arch(arch: Tuple[int, int]) -> List[str]:
"""
Args:
arch:
"""
...
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool) -> List[ConvAlgoDesp]:
"""
Args:
inp:
weight:
out:
layout_i:
layout_w:
layout_o:
interleave_i:
interleave_w:
interleave_o:
arch:
op_type:
mask_width:
auto_fp32_accum:
fp32_accum:
"""
...
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
"""
Args:
desp:
arch:
stream_int:
"""
...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5) -> Tuple[ConvTuneResult, float]:
"""
Args:
op_type:
inp:
weight:
output:
layout_i:
layout_w:
layout_o:
interleave_i:
interleave_w:
interleave_o:
arch:
mask:
mask_argsort:
indices:
reverse_mask:
mask_filter:
mask_width:
mask_output:
alpha:
beta:
stream_int:
auto_fp32_accum:
fp32_accum:
num_run:
"""
...
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1) -> Tuple[Any, bool]:
"""
Args:
op_type:
i_dtype:
w_dtype:
o_dtype:
k:
c:
arch:
mask_width:
"""
...
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False) -> None:
"""
Args:
profile_res:
op_type:
inp:
weight:
output:
mask:
mask_argsort:
mask_output:
indices:
reverse_mask:
mask_filter:
mask_width:
alpha:
beta:
stream_int:
workspace:
verbose:
timer:
force_nvrtc:
"""
...
def query_workspace_size(self, desp: ConvAlgoDesp, splitk: int, op_type: int, N: int, C: int, K: int, kv: int) -> int:
"""
Args:
desp:
splitk:
op_type:
N:
C:
K:
kv:
"""
...
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview import Tensor
from cumm.tensorview.gemm import NVRTCParams
from spconv.core_cc.csrc.sparse.convops import GemmTuneResult
from cumm.tensorview import CUDAKernelTimer
class GemmTunerSimple:
def __init__(self, desps: List[GemmAlgoDesp]) -> None:
"""
Args:
desps:
"""
...
@staticmethod
def get_available_algo_str_from_arch(arch: Tuple[int, int]) -> List[str]:
"""
Args:
arch:
"""
...
def get_all_available(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int) -> List[GemmAlgoDesp]:
"""
Args:
a:
b:
c:
trans_a:
trans_b:
trans_c:
arch:
shuffle_type:
"""
...
def cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
"""
Args:
desp:
arch:
stream_int:
"""
...
def tune_and_cache(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds: Tensor, b_inds: Tensor, c_inds: Tensor, hint: int = 0, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, num_run: int = 5) -> Tuple[GemmTuneResult, float]:
"""
Args:
a:
b:
c:
trans_a:
trans_b:
trans_c:
arch:
shuffle_type:
a_inds:
b_inds:
c_inds:
hint:
alpha:
beta:
stream_int:
num_run:
"""
...
def get_tuned_algo(self, a_dtype: int, b_dtype: int, c_dtype: int, a_shape: List[int], b_shape: List[int], c_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds_shape: List[int], b_inds_shape: List[int], c_inds_shape: List[int], hint: int = 0) -> Tuple[Any, bool]:
"""
Args:
a_dtype:
b_dtype:
c_dtype:
a_shape:
b_shape:
c_shape:
trans_a:
trans_b:
trans_c:
arch:
shuffle_type:
a_inds_shape:
b_inds_shape:
c_inds_shape:
hint:
"""
...
def run_with_tuned_result(self, profile_res, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], stream_int: int, shuffle_type: int, a_inds: Tensor, b_inds: Tensor, c_inds: Tensor, hint: int = 0, alpha: float = 1.0, beta: float = 0.0, workspace: Tensor = Tensor(), timer: CUDAKernelTimer = CUDAKernelTimer(False), force_nvrtc: bool = False) -> None:
"""
Args:
profile_res:
a:
b:
c:
trans_a:
trans_b:
trans_c:
arch:
stream_int:
shuffle_type:
a_inds:
b_inds:
c_inds:
hint:
alpha:
beta:
workspace:
timer:
force_nvrtc:
"""
...
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ConvGemmOps:
@staticmethod
def get_compute_capability(index: int = -1) -> Tuple[int, int]:
"""
Args:
index:
"""
...
@staticmethod
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
"""
1. this function need to take a out features
that from subm first mm.
2. this function don't support CPU.
Args:
allocator:
ext_mm:
gemm_tuner:
all_w_is_krsc:
filter_hwio:
features:
filters:
indice_pairs:
indice_pair_num:
num_activate_out:
inverse:
subm:
algo:
stream_int:
"""
...
@staticmethod
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
"""
Args:
allocator:
ext_mm:
gemm_tuner:
all_w_is_krsc:
filter_hwio:
features:
filters:
out_bp:
indice_pairs:
indice_pair_num:
inverse:
subm:
algo:
stream_int:
"""
...
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int:
"""
Args:
allocator:
conv_tuner:
features:
filters:
pair_fwd:
pair_mask_fwd_splits:
mask_argsort_fwd_splits:
num_activate_out:
masks:
is_train:
is_subm:
stream_int:
timer:
auto_fp32_accum:
fp32_accum:
"""
...
@staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None:
"""
Args:
allocator:
conv_tuner:
features:
filters:
out_bp:
pair_fwd:
pair_bwd:
pair_mask_fwd_splits:
pair_mask_bwd_splits:
mask_argsort_fwd_splits:
mask_argsort_bwd_splits:
mask_output_fwd:
masks:
mask_width:
is_subm:
stream_int:
timer:
auto_fp32_accum:
fp32_accum:
"""
...
......@@ -3,3 +3,10 @@ from pccm.stubs import EnumValue, EnumClassValue
class CompileInfo:
@staticmethod
def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ...
@staticmethod
def arch_is_compiled(arch: Tuple[int, int]) -> bool:
"""
Args:
arch:
"""
...
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview.gemm import GemmParams
class GemmMainUnitTest:
@staticmethod
def get_all_algo_desp() -> List[Any]: ...
def get_all_algo_desp() -> List[GemmAlgoDesp]: ...
@staticmethod
def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type: str = "0", a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]:
def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type: int = 0, a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]:
"""
Args:
a_shape:
......
......@@ -26,7 +26,7 @@ from .indices import SparseConvIndicesKernel, CudaCommonKernel, SparseConvIndice
from .maxpool import IndiceMaxPool, IndiceMaxPoolCPU
from .gather import GatherCPU
from .alloc import ExternalAllocator, ThrustAllocator
from spconv.constants import SpconvAllocatorKeys
from spconv.constants import AllocKeys
class CustomThrustLib(pccm.Class):
def __init__(self):
......@@ -34,7 +34,7 @@ class CustomThrustLib(pccm.Class):
self.add_dependency(ThrustLib)
# https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
if compat.InLinux:
self.build_meta.add_public_cflags("nvcc", "-Xcompiler", "-fno-gnu-unique", "-Xcompiler", "-fvisibility=hidden")
self.build_meta.add_public_cflags("nvcc", "-Xcompiler -fno-gnu-unique", "-Xcompiler -fvisibility=hidden")
class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin):
......@@ -76,6 +76,7 @@ class SpconvOps(pccm.Class):
super().__init__()
self.add_dependency(ThrustCustomAllocatorV2, ExternalAllocator, GemmBasicHost, ThrustAllocator)
self.ndims = [1, 2, 3, 4]
self.cuda_common_kernel = CudaCommonKernel()
for ndim in self.ndims:
p2v = Point2Voxel(dtypes.float32, ndim)
p2v_cpu = Point2VoxelCPU(dtypes.float32, ndim)
......@@ -102,6 +103,11 @@ class SpconvOps(pccm.Class):
indices,
f"SpconvIndices{ndim}D")
for name in dir(AllocKeys):
if not name.startswith("__"):
v = getattr(AllocKeys, name)
self.add_static_const("k" + name, "auto", f"tv::make_const_string({pccm.literal(v)})")
@pccm.pybind.mark
@pccm.static_function
def cumm_version(self):
......@@ -194,12 +200,15 @@ class SpconvOps(pccm.Class):
code = pccm.FunctionCode()
code.arg("indices, hashdata_k, hashdata_v", "tv::Tensor")
code.arg("indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds", "tv::Tensor")
code.arg("indice_num_per_loc", "tv::Tensor")
code.arg("num_out_act", "int")
code.arg("batch_size", "int")
code.arg("output_dims, input_dims", f"std::vector<int>")
code.arg("ksize, stride, padding, dilation", f"std::vector<int>")
code.arg("transposed", f"bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("use_bound_algo", "bool", "false")
if CUMM_CPU_ONLY_BUILD:
return code.make_invalid()
......@@ -225,9 +234,11 @@ class SpconvOps(pccm.Class):
}}
return SpconvIndices{ndim}D::generate_conv_inds_stage2(indices,
hashdata_k, hashdata_v, indice_pairs,
indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds, num_out_act,
indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds,
indice_num_per_loc, num_out_act,
batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
ksize_, stride_, padding_, dilation_, transposed, stream_int,
use_bound_algo);
}}
""")
code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""")
......@@ -481,6 +492,93 @@ class SpconvOps(pccm.Class):
""")
return code
@pccm.pybind.mark
@pccm.cuda.static_function
def indice_maxpool(self):
code = pccm.FunctionCode()
code.arg("out_features, features", "tv::Tensor")
code.arg("indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor")
code.arg("num_activate_out", "int")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.add_dependency(IndiceMaxPoolCPU)
if not CUMM_CPU_ONLY_BUILD:
code.add_dependency(IndiceMaxPool)
code.raw(f"""
tv::check_shape(out_features, {{-1, features.dim(1)}});
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
""")
with code.for_("int i = 0; i < indice_pair_num.dim(0); ++i"):
code.raw(f"""
int nhot = indice_pair_num_cpu_ptr[i];
nhot = std::min(nhot, int(indice_pairs.dim(2)));
if (nhot <= 0){{
continue;
}}
auto inp_indices = indice_pairs[0][i].slice_first_axis(0, nhot);
auto out_indices = indice_pairs[1][i].slice_first_axis(0, nhot);
if (features.is_cpu()){{
IndiceMaxPoolCPU::forward(out_features, features, out_indices, inp_indices);
}}
""")
if not CUMM_CPU_ONLY_BUILD:
with code.else_():
code.raw(f"""
IndiceMaxPool::forward(out_features, features, out_indices, inp_indices, stream);
""")
else:
code.raw(f"""
TV_THROW_RT_ERR("not implemented in cpu-only spconv!!! ")
""")
return code
@pccm.pybind.mark
@pccm.cuda.static_function
def indice_maxpool_backward(self):
code = pccm.FunctionCode()
code.arg("din, features, out_features, out_bp", "tv::Tensor")
code.arg("indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.add_dependency(IndiceMaxPoolCPU)
if not CUMM_CPU_ONLY_BUILD:
code.add_dependency(IndiceMaxPool)
code.raw(f"""
tv::check_shape(din, features.shape());
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
""")
with code.for_("int i = 0; i < indice_pair_num.dim(0); ++i"):
code.raw(f"""
int nhot = indice_pair_num_cpu_ptr[i];
nhot = std::min(nhot, int(indice_pairs.dim(2)));
if (nhot <= 0){{
continue;
}}
auto inp_indices = indice_pairs[0][i].slice_first_axis(0, nhot);
auto out_indices = indice_pairs[1][i].slice_first_axis(0, nhot);
if (features.is_cpu()){{
IndiceMaxPoolCPU::backward(out_features, features, out_bp, din, out_indices, inp_indices);
}}
""")
if not CUMM_CPU_ONLY_BUILD:
with code.else_():
code.raw(f"""
IndiceMaxPool::backward(out_features, features, out_bp, din, out_indices, inp_indices, stream);
""")
else:
code.raw(f"""
TV_THROW_RT_ERR("not implemented in cpu-only spconv!!! ")
""")
return code
@pccm.pybind.mark
@pccm.cuda.static_function
def maxpool_implicit_gemm_forward(self):
......@@ -597,7 +695,7 @@ class SpconvOps(pccm.Class):
"""
code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", CudaCommonKernel())
code.add_param_class("cudakers", self.cuda_common_kernel)
code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
......@@ -613,7 +711,7 @@ class SpconvOps(pccm.Class):
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
thrust::stable_sort_by_key(thrust_ctx, ptr_tr, ptr_tr + data.dim(0), ptr_k, SmallOrEqualTo<uint32_t>());
}});
tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
// tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices;
""")
return code.ret("tv::Tensor")
......@@ -646,7 +744,7 @@ class SpconvOps(pccm.Class):
}}
"""
code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", CudaCommonKernel())
code.add_param_class("cudakers", self.cuda_common_kernel)
if not use_allocator:
code.raw(f"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
......@@ -715,7 +813,7 @@ class SpconvOps(pccm.Class):
}}
"""
code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", CudaCommonKernel())
code.add_param_class("cudakers", self.cuda_common_kernel)
code.raw(f"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
// auto timer = tv::CudaContextTimer<>();
......@@ -774,7 +872,7 @@ class SpconvOps(pccm.Class):
}}
"""
code.add_dependency(CustomThrustLib, TensorViewKernel)
code.add_param_class("cudakers", CudaCommonKernel())
code.add_param_class("cudakers", self.cuda_common_kernel)
if not use_allocator:
code.raw(f"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
......@@ -1141,6 +1239,26 @@ class SpconvOps(pccm.Class):
""")
return code.ret("int")
@pccm.pybind.mark
@pccm.static_function
def get_indice_gen_workspace_size(self):
code = pccm.code()
code.arg("kv", "size_t")
code.arg("num_act_in", "size_t")
code.arg("num_act_out_bound", "size_t")
code.arg("subm, use_int64_hash_k", "bool")
code.raw(f"""
if (subm){{
return 2 * num_act_in * (use_int64_hash_k ? 2 : 3) * sizeof(int);
}}else{{
size_t pair_single_size = kv * num_act_in;
size_t ind_uniq_and_bkp_size = (pair_single_size + 1) * 2 * (use_int64_hash_k ? sizeof(int64_t) : sizeof(int32_t));
size_t hash_size = 2 * num_act_out_bound * (use_int64_hash_k ? 2 : 3) * sizeof(int);
return ind_uniq_and_bkp_size + hash_size;
}}
""")
return code.ret("std::size_t")
@pccm.pybind.mark
@pccm.static_function
def get_indice_pairs_implicit_gemm(self):
......@@ -1154,6 +1272,8 @@ class SpconvOps(pccm.Class):
code.arg("subm, transposed, is_train", f"bool")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("num_out_act_bound", f"int", "-1")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
throw std::runtime_error("this function can only be used with CUDA.")
......@@ -1192,13 +1312,13 @@ class SpconvOps(pccm.Class):
int mask_split_count = is_mask_split ? 2 : 1;
tv::Tensor pair;
if (subm){{
pair = allocator.full_int({pccm.literal(SpconvAllocatorKeys.Pair)},
pair = allocator.full_int({pccm.literal(AllocKeys.Pair)},
{{2, kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
}}else{{
pair = allocator.full_int({pccm.literal(SpconvAllocatorKeys.Pair)},
pair = allocator.full_int({pccm.literal(AllocKeys.Pair)},
{{kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
}}
auto indice_num_per_loc = allocator.zeros({pccm.literal(SpconvAllocatorKeys.IndiceNumPerLoc)},
auto indice_num_per_loc = allocator.zeros({pccm.literal(AllocKeys.IndiceNumPerLoc)},
{{kv}}, indices.dtype(), indices.device());
tv::Tensor mask_tensor = tv::zeros({{mask_split_count}}, tv::uint32, -1);
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
......@@ -1213,39 +1333,48 @@ class SpconvOps(pccm.Class):
mask_tensor_ptr[1] = uint32_t(second);
}}
else{{
mask_tensor_ptr[1] = 0xffffffff;
mask_tensor_ptr[0] = 0xffffffff;
}}
tv::Tensor out_inds;
ThrustAllocator thrustalloc(allocator);
int num_act_out = 0;
if (subm){{
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
out_inds = indices;
num_act_out = indices.dim(0);
int num_points = out_inds.dim(0);
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_points * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_points * 2}}, tv::int32, 0);
hash_k_guard = allocator.empty_guard({{num_points * 2}},
tv::int64, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_v_gurad = allocator.empty_guard({{num_points * 2}},
tv::int32, 0, {pccm.literal(AllocKeys.HashV)});
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}}, tv::int32, 0);
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}},
tv::int32, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
auto pair_mask = allocator.empty({pccm.literal(SpconvAllocatorKeys.PairMask)},
auto pair_mask = allocator.empty({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, false, stream_int);
auto mask_argsort = allocator.empty({pccm.literal(SpconvAllocatorKeys.MaskArgSort)},
{{mask_split_count, out_inds.dim(0)}}, tv::uint32, 0);
auto mask_argsort = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, out_inds.dim(0)}}, tv::int32, 0);
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int);
}}
}}else{{
auto pair_bwd = pair;
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}},
indice_uniq_dtype, 0, {pccm.literal(AllocKeys.IndicePairsUniq)});
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}},
indice_uniq_dtype, 0, {pccm.literal(AllocKeys.IndicePairsUniqBackup)});
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
generate_conv_inds_mask_stage1(indices, pair_bwd, indice_pairs_uniq,
......@@ -1253,28 +1382,34 @@ class SpconvOps(pccm.Class):
stride, padding, dilation, transposed, stream_int);
indice_pairs_uniq_bkp_guard->tensor.copy_(indice_pairs_uniq, tvctx);
// TODO pytorch unique may be faster?
int num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int) - 1;
num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int);
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
num_act_out = num_out_act_bound;
}}
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
out_inds = allocator.empty({pccm.literal(SpconvAllocatorKeys.OutIndices)},
out_inds = allocator.empty({pccm.literal(AllocKeys.OutIndices)},
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0);
auto pair_fwd = allocator.full_int({pccm.literal(SpconvAllocatorKeys.PairFwd)},
auto pair_fwd = allocator.full_int({pccm.literal(AllocKeys.PairFwd)},
{{kv, num_act_out}}, -1, indices.dtype(), indices.device());
auto pair_mask_fwd = allocator.zeros({pccm.literal(SpconvAllocatorKeys.PairMask)},
auto pair_mask_fwd = allocator.zeros({pccm.literal(AllocKeys.PairMask)},
{{mask_split_count, num_act_out}}, tv::uint32, 0);
auto pair_mask_bwd = tv::Tensor();
if (is_train){{
pair_mask_bwd = allocator.zeros({pccm.literal(SpconvAllocatorKeys.PairMaskBwd)},
pair_mask_bwd = allocator.zeros({pccm.literal(AllocKeys.PairMaskBwd)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
}}
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_act_out * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}}, tv::int32, 0);
hash_k_guard = allocator.empty_guard({{num_act_out * 2}},
tv::int64, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}},
tv::int32, 0, {pccm.literal(AllocKeys.HashV)});
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}}, tv::int32, 0);
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}},
tv::int32, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
......@@ -1283,23 +1418,24 @@ class SpconvOps(pccm.Class):
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
auto mask_argsort_fwd = allocator.empty({pccm.literal(SpconvAllocatorKeys.MaskArgSort)},
{{mask_split_count, out_inds.dim(0)}}, tv::uint32, 0);
auto mask_argsort_fwd = allocator.empty({pccm.literal(AllocKeys.MaskArgSort)},
{{mask_split_count, out_inds.dim(0)}}, tv::int32, 0);
tv::Tensor mask_argsort_bwd = tv::Tensor();
if (is_train){{
mask_argsort_bwd = allocator.zeros({pccm.literal(SpconvAllocatorKeys.MaskArgSortBwd)},
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
mask_argsort_bwd = allocator.zeros({pccm.literal(AllocKeys.MaskArgSortBwd)},
{{mask_split_count, indices.dim(0)}}, tv::int32, 0);
}}
if (is_mask_split){{
for (int j = 0; j < mask_split_count; ++j){{
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
if (!is_train){{
sort_1d_by_key_split_allocator_v2(pair_mask_fwd[j], thrustalloc,
mask_tensor[j], mask_argsort_fwd[j], stream_int);
mask_tensor_sub, mask_argsort_fwd[j], stream_int);
}}else{{
sort_1d_by_key_split_allocator_v2(pair_mask_fwd[j], thrustalloc,
mask_tensor[j], mask_argsort_fwd[j], stream_int);
mask_tensor_sub, mask_argsort_fwd[j], stream_int);
sort_1d_by_key_split_allocator_v2(pair_mask_bwd[j], thrustalloc,
mask_tensor[j], mask_argsort_bwd[j], stream_int);
mask_tensor_sub, mask_argsort_bwd[j], stream_int);
}}
}}
}}else{{
......@@ -1314,9 +1450,9 @@ class SpconvOps(pccm.Class):
}}
}}
}}
return mask_tensor;
return std::make_tuple(mask_tensor, num_act_out);
""")
return code.ret("tv::Tensor")
return code.ret("std::tuple<tv::Tensor, int>")
@pccm.pybind.mark
@pccm.static_function
......@@ -1329,15 +1465,12 @@ class SpconvOps(pccm.Class):
code.arg("algo", "int")
code.arg("ksize, stride, padding, dilation, out_padding", f"std::vector<int>")
code.arg("subm, transposed", f"bool")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"""
throw std::runtime_error("this function can only be used with CUDA.")
""")
return code
code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("num_out_act_bound", f"int", "-1")
code.raw(f"""
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kNative, "only support kNative");
......@@ -1362,15 +1495,17 @@ class SpconvOps(pccm.Class):
}}
}}
tv::Tensor pair;
pair = allocator.full_int({pccm.literal(SpconvAllocatorKeys.Pair)},
pair = allocator.full_int({pccm.literal(AllocKeys.Pair)},
{{2, kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
auto indice_num_per_loc = allocator.zeros({pccm.literal(SpconvAllocatorKeys.IndiceNumPerLoc)},
auto indice_num_per_loc = allocator.zeros({pccm.literal(AllocKeys.IndiceNumPerLoc)},
{{kv}}, indices.dtype(), indices.device());
tv::Tensor out_inds;
int num_act_out = -1;
""")
with code.if_("subm"):
code.raw(f"""
num_act_out = indices.dim(0);
if (indices.is_cpu()){{
generate_subm_conv_inds_cpu(indices, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation);
......@@ -1384,12 +1519,15 @@ class SpconvOps(pccm.Class):
int num_points = out_inds.dim(0);
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_points * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_points * 2}}, tv::int32, 0);
hash_k_guard = allocator.empty_guard({{num_points * 2}},
tv::int64, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_v_gurad = allocator.empty_guard({{num_points * 2}},
tv::int32, 0, {pccm.literal(AllocKeys.HashV)});
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}}, tv::int32, 0);
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}},
tv::int32, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
......@@ -1406,10 +1544,10 @@ class SpconvOps(pccm.Class):
with code.else_():
code.raw(f"""
if (indices.is_cpu()){{
out_inds = allocator.empty({pccm.literal(SpconvAllocatorKeys.OutIndices)},
TV_ASSERT_RT_ERR(num_out_act_bound <= 0, "cpu algo don't support out bound")
out_inds = allocator.empty({pccm.literal(AllocKeys.OutIndices)},
{{kv * indices.dim(0), indices.dim(1)}}, indices.dtype(), -1);
generate_conv_inds_cpu(indices, pair, out_inds, indice_num_per_loc,
num_act_out = generate_conv_inds_cpu(indices, pair, out_inds, indice_num_per_loc,
batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed);
}}
......@@ -1422,9 +1560,13 @@ class SpconvOps(pccm.Class):
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq_guard = allocator.empty_guard(
{{int64_t(pair.numel() / 2 + 1)}}, indice_uniq_dtype, 0,
{pccm.literal(AllocKeys.IndicePairsUniq)});
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard(
{{int64_t(pair.numel() / 2 + 1)}}, indice_uniq_dtype, 0,
{pccm.literal(AllocKeys.IndicePairsUniqBackup)});
generate_conv_inds_stage1(indices, pair, indice_pairs_uniq,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
......@@ -1432,27 +1574,35 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq_bkp_guard->tensor.copy_(indice_pairs_uniq, tvctx);
// TODO pytorch unique may be faster?
int num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int) - 1;
num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int);
bool use_bound_algo = false;
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
num_act_out = num_out_act_bound;
use_bound_algo = true;
}}
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
out_inds = allocator.empty({pccm.literal(SpconvAllocatorKeys.OutIndices)},
out_inds = allocator.empty({pccm.literal(AllocKeys.OutIndices)},
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0);
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_act_out * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}}, tv::int32, 0);
hash_k_guard = allocator.empty_guard({{num_act_out * 2}},
tv::int64, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}},
tv::int32, 0, {pccm.literal(AllocKeys.HashV)});
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}}, tv::int32, 0);
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}},
tv::int32, 0, {pccm.literal(AllocKeys.HashKOrKV)});
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
generate_conv_inds_stage2(indices, hash_k, hash_v, pair,
num_act_out = generate_conv_inds_stage2(indices, hash_k, hash_v, pair,
indice_pairs_uniq, indice_pairs_uniq_bkp_guard->tensor,
out_inds, num_act_out,
out_inds, indice_num_per_loc, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
transposed, stream_int, use_bound_algo);
}}
""")
else:
......@@ -1462,18 +1612,6 @@ class SpconvOps(pccm.Class):
}}
""")
code.raw(f"""
return;
return num_act_out;
""")
return code
@pccm.pybind.mark
@pccm.static_function
def test_allocator(self):
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.raw(f"""
auto guard = allocator.zeros_guard({{1, 2, 3}}, tv::int32, 0);
tv::ssprint("????");
""")
return code
\ No newline at end of file
return code.ret("int")
import pccm
from cumm.common import TensorView, TensorViewCPU, TensorViewKernel, ThrustLib
from spconv.constants import AllocKeys
class ExternalAllocatorGuard(pccm.Class):
def __init__(self):
super().__init__()
......@@ -51,6 +53,9 @@ class ExternalAllocator(pccm.Class):
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
......@@ -61,6 +66,9 @@ class ExternalAllocator(pccm.Class):
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
......@@ -72,6 +80,9 @@ class ExternalAllocator(pccm.Class):
code.arg("value", "int")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
......@@ -83,6 +94,15 @@ class ExternalAllocator(pccm.Class):
code.arg("value", "float")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True, pure_virtual=True)
def get_tensor_by_name(self):
code = pccm.code()
code.arg("name", "std::string")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
......@@ -105,9 +125,11 @@ class ExternalAllocator(pccm.Class):
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("name", "std::string", "\"\"")
code.arg("stream", "std::uintptr_t", "0")
code.raw(f"""
// "" means temp memory
auto ten = zeros("", shape, dtype, device);
auto ten = zeros(name, shape, dtype, device, true, stream);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
......@@ -120,8 +142,10 @@ class ExternalAllocator(pccm.Class):
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("name", "std::string", "\"\"")
code.arg("stream", "std::uintptr_t", "0")
code.raw(f"""
auto ten = empty("", shape, dtype, device);
auto ten = empty(name, shape, dtype, device, true, stream);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
......@@ -135,8 +159,10 @@ class ExternalAllocator(pccm.Class):
code.arg("value", "int")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("name", "std::string", "\"\"")
code.arg("stream", "std::uintptr_t", "0")
code.raw(f"""
auto ten = full_int("", shape, value, dtype, device);
auto ten = full_int(name, shape, value, dtype, device, true, stream);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
......@@ -150,8 +176,10 @@ class ExternalAllocator(pccm.Class):
code.arg("value", "int")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("name", "std::string", "\"\"")
code.arg("stream", "std::uintptr_t", "0")
code.raw(f"""
auto ten = full_float("", shape, value, dtype, device);
auto ten = full_float(name, shape, value, dtype, device, true, stream);
return std::make_{self.ptr_type}<ExternalAllocatorGuard>(ten, [this](tv::Tensor t){{
this->free(t);
}});
......@@ -179,7 +207,7 @@ class ThrustAllocator(pccm.Class):
code.arg("num_bytes", "std::ptrdiff_t")
code.ret("char*")
code.raw(f"""
auto ten = allocator_.empty("", {{num_bytes}}, tv::uint8, 0);
auto ten = allocator_.empty({pccm.literal(AllocKeys.ThrustTemp)}, {{num_bytes}}, tv::uint8, 0);
return reinterpret_cast<char*>(ten.raw_data());
""")
return code
......@@ -193,3 +221,158 @@ class ThrustAllocator(pccm.Class):
return allocator_.free_noexcept(tv::from_blob(ptr, {{num_bytes}}, tv::uint8, 0));
""")
return code
class StaticAllocator(ExternalAllocator):
"""a simple allocator for tensorrt plugin.
"""
def __init__(self):
super().__init__()
self.add_dependency(TensorView)
self.add_member("tensor_dict_", "std::unordered_map<std::string, tv::Tensor>")
self.add_member("repr_", "std::string")
self.add_member("thrust_tmp_tensor_", "tv::Tensor")
self.grow = 1.5
@pccm.pybind.mark
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("tensor_dict", "std::unordered_map<std::string, tv::Tensor>")
code.ctor_init("tensor_dict_", "tensor_dict")
code.raw(f"""
std::stringstream ss;
for (auto& p : tensor_dict){{
tv::ssprint(ss, p.first, p.second.shape(), tv::dtype_str(p.second.dtype()), "\\n");
}}
repr_ = ss.str();
""")
return code
@pccm.member_function(virtual=True)
def _get_raw_and_check(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.raw(f"""
auto res = get_tensor_by_name(name);
size_t total = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
TV_ASSERT_RT_ERR(res.nbytes() >= total * tv::bit_size(tv::DType(dtype))
&& res.device() == device, "alloc failed", shape, res.shape());
return tv::from_blob(res.raw_data(), shape, dtype, device);
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.member_function(virtual=True)
def zeros(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0")
code.raw(f"""
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream));
auto blob = _get_raw_and_check(name, shape, dtype, device);
return blob.zero_(tvctx);
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.member_function(virtual=True)
def empty(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0")
code.raw(f"""
if (name == {pccm.literal(AllocKeys.ThrustTemp)}){{
// thrust tmp shouldn't inside tensor_dict. use a simple method to allocate
// we assume each allocator always handle one stream
// so we can just use one tensor
tv::Tensor res = thrust_tmp_tensor_;
if (res.empty()){{
res = tv::empty(shape, dtype, device);
thrust_tmp_tensor_ = res;
}}
if (shape[0] > thrust_tmp_tensor_.dim(0)){{
res = tv::empty({{int64_t(shape[0] * {self.grow})}}, dtype, device);
thrust_tmp_tensor_ = res;
}}
return res;
}}else{{
auto blob = _get_raw_and_check(name, shape, dtype, device);
return blob;
}}
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.member_function(virtual=True)
def full_int(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("value", "int")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0")
code.raw(f"""
auto tvctx = tv::Context();
auto blob = _get_raw_and_check(name, shape, dtype, device);
return blob.fill_(tvctx, value);
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.member_function(virtual=True)
def full_float(self):
code = pccm.code()
code.arg("name", "std::string")
code.arg("shape", "std::vector<int64_t>")
code.arg("value", "float")
code.arg("dtype", "int")
code.arg("device", "int")
code.arg("is_temp_memory", "bool", "false")
code.arg("stream", "std::uintptr_t", "0")
code.raw(f"""
auto blob = _get_raw_and_check(name, shape, dtype, device);
return blob.fill_(tvctx, value);
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.member_function(virtual=True)
def get_tensor_by_name(self):
code = pccm.code()
code.arg("name", "std::string")
code.raw(f"""
TV_ASSERT_RT_ERR(tensor_dict_.find(name) != tensor_dict_.end(), "can't find", name, "exists:\\n", repr_);
return tensor_dict_.at(name);
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark
@pccm.member_function(virtual=True)
def free(self):
code = pccm.code()
code.arg("ten", "tv::Tensor")
return code
@pccm.pybind.mark
@pccm.member_function(virtual=True)
def free_noexcept(self):
code = pccm.code()
code.arg("ten", "tv::Tensor")
return code
from typing import Optional
import pccm
from cumm.gemm.main import GemmMainUnitTest
from cumm.common import GemmBasicHost, NlohmannJson, TensorView
from cumm.constants import CUMM_CPU_ONLY_BUILD
from cumm.conv.main import ConvMainUnitTest
from cumm.gemm.algospec.core import (_GEMM_MIN_ARCH_TO_ALGO, GemmAlgo,
ShuffleStrideType,
get_available_algo_str_from_arch,
get_min_arch_of_algo_str)
from cumm.gemm.main import GemmMainUnitTest
from spconv.constants import NDIM_DONT_CARE, SPCONV_BWD_SPLITK, AllocKeys
from spconv.core import AlgoHint, ConvAlgo
from spconv.csrc.sparse.gather import GatherCPU
from .alloc import ExternalAllocator
from spconv.core import ConvAlgo
from spconv.constants import SpconvAllocatorKeys
from cumm.constants import CUMM_CPU_ONLY_BUILD
from cumm.common import GemmBasicHost, TensorView, NlohmannJson
from cumm.common import CompileInfo
class ExternalSpconvMatmul(pccm.Class):
"""a helper class to warp matmul operations
because we don't want to implement matmul
(link to cublas/mkl/pytorch) in python package.
"""
def __init__(self):
super().__init__()
self.add_dependency(TensorView)
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True)
def indice_conv_init_gemm(self):
code = pccm.code()
code.arg("features_n, filters_n", "std::string")
code.arg("all_weight_is_krsc, is_kc_not_ck", "bool")
code.arg("kv_center, out_channel", "int")
code.arg("stream_int", "std::uintptr_t", "0")
code.raw(f"""
TV_THROW_RT_ERR("not implemented, override this and use preferred blas!!!");
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True)
def indice_conv_cpu_gemm(self):
code = pccm.code()
code.arg("inp_buffer_n, out_buffer_n, filters_n", "std::string")
code.arg("all_weight_is_krsc, is_kc_not_ck", "bool")
code.arg("nhot, index", "int")
code.raw(f"""
TV_THROW_RT_ERR("not implemented, override this and use preferred cpu blas!!!");
""")
return code
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True)
def indice_conv_bwd_init_gemm(self):
code = pccm.code()
code.arg("features_n, filters_n, out_bp_n, dfilters_n", "std::string")
code.arg("all_weight_is_krsc, is_kc_not_ck", "bool")
code.arg("kv_center", "int")
code.arg("stream_int", "std::uintptr_t", "0")
code.raw(f"""
TV_THROW_RT_ERR("not implemented, override this and use preferred blas!!!");
""")
return code.ret("tv::Tensor")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True)
def indice_conv_bwd_cpu_gemm(self):
code = pccm.code()
code.arg("inp_buffer_n, out_buffer_n, filters_n, dfilters_n",
"std::string")
code.arg("all_weight_is_krsc, is_kc_not_ck", "bool")
code.arg("nhot, index", "int")
code.raw(f"""
TV_THROW_RT_ERR("not implemented, override this and use preferred cpu blas!!!");
""")
return code
class SimpleExternalSpconvMatmul(ExternalSpconvMatmul):
"""a helper class to warp matmul operations
because we don't want to implement matmul
(link to cublas/mkl/pytorch) in python package.
"""
def __init__(self):
super().__init__()
self.add_dependency(TensorView, ExternalAllocator)
self.build_meta.add_libraries("cublasLt")
self.add_include("cublasLt.h")
self.add_member("alloc_", "ExternalAllocator&")
self.add_member("handle_", "cublasLtHandle_t", "0")
@pccm.pybind.mark
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("alloc", "ExternalAllocator&")
code.ctor_init("alloc_", "alloc")
code.raw(f"""
auto stat = cublasLtCreate(&handle_);
TV_ASSERT_RT_ERR(CUBLAS_STATUS_SUCCESS == stat, "err");
""")
return code
@pccm.destructor
def destructor(self):
code = pccm.code()
code.raw(f"""
if (handle_){{
cublasLtDestroy(handle_);
}}
""")
return code
@pccm.static_function
def check_cublas_status(self):
code = pccm.code()
code.arg("status", "cublasStatus_t")
code.raw(f"""
if (status != CUBLAS_STATUS_SUCCESS) {{
printf("cuBLAS API failed with status %d\\n", status);
throw std::logic_error("cuBLAS API failed");
}}
""")
return code
@pccm.static_function
def tv_dtype_to_blaslt(self):
code = pccm.code()
code.arg("dtype", "tv::DType")
code.raw(f"""
switch (dtype) {{
case tv::float32:
return CUDA_R_32F;
case tv::float16:
return CUDA_R_16F;
case tv::int32:
return CUDA_R_32I;
case tv::int8:
return CUDA_R_8I;
case tv::uint32:
return CUDA_R_32U;
default:
return CUDA_R_32F;
}}
""")
return code.ret("decltype(CUDA_R_16F)")
@pccm.static_function(inline=True)
def tv_dtype_to_compute(self):
code = pccm.code()
code.arg("dtype", "tv::DType")
with code.macro_if_("CUDART_VERSION >= 11000"):
code.raw(f"""
switch (dtype) {{
case tv::float32:
return CUBLAS_COMPUTE_32F;
case tv::float16:
return CUBLAS_COMPUTE_16F;
case tv::int32:
return CUBLAS_COMPUTE_32I;
case tv::int8:
return CUBLAS_COMPUTE_32F;
case tv::uint32:
return CUBLAS_COMPUTE_32F;
default:
return CUBLAS_COMPUTE_32F;
}}
""")
with code.macro_else_():
code.raw(f"""
switch (dtype) {{
case tv::float32:
return CUDA_R_32F;
case tv::float16:
return CUDA_R_16F;
case tv::int32:
return CUDA_R_32I;
case tv::int8:
return CUDA_R_8I;
case tv::uint32:
return CUDA_R_32U;
default:
return CUDA_R_32F;
}}
""")
code.macro_endif_()
return code.ret("decltype(auto)")
@pccm.static_function
def matmul_colmajor(self):
code = pccm.code()
code.arg("handle", "cublasLtHandle_t")
code.arg("stream", "cudaStream_t")
code.arg("a, b, c", "tv::Tensor")
code.arg("transA, transB", "bool")
code.raw(f"""
bool transC = false;
auto m = a.dim(int(!transA));
auto k = a.dim(int(transA));
auto k2 = b.dim(int(!transB));
auto n = b.dim(int(transB));
TV_ASSERT_INVALID_ARG(k == k2, "error");
TV_ASSERT_INVALID_ARG(a.dtype() == b.dtype(), "error");
tv::TensorShape c_shape;
if (transC) {{
c_shape = {{m, n}};
}} else {{
c_shape = {{n, m}};
}}
if (c.empty()) {{
c = tv::Tensor(c_shape, a.dtype(), a.device());
}} else {{
TV_ASSERT_INVALID_ARG(c.dim(0) == c_shape[0] && c.dim(1) == c_shape[1],
"error");
}}
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
decltype(CUDA_R_16F) scalarType = CUDA_R_16F;
#if CUDART_VERSION >= 11000
decltype(CUBLAS_COMPUTE_32F) computeType = CUBLAS_COMPUTE_32F;
#endif
if (a.dtype() == tv::float16 && b.dtype() == tv::float16 &&
c.dtype() == tv::float16) {{
scalarType = CUDA_R_16F;
#if CUDART_VERSION >= 11000
computeType = CUBLAS_COMPUTE_16F;
#endif
}} else if (a.dtype() == tv::float32 && b.dtype() == tv::float32 &&
c.dtype() == tv::float16) {{
scalarType = CUDA_R_32F;
#if CUDART_VERSION >= 11000
computeType = CUBLAS_COMPUTE_32F;
#endif
}} else if (a.dtype() == tv::float32 && b.dtype() == tv::float32 &&
c.dtype() == tv::float32) {{
scalarType = CUDA_R_32F;
#if CUDART_VERSION >= 11000
computeType = CUBLAS_COMPUTE_32F;
#endif
}} else if (a.dtype() == tv::float16 && b.dtype() == tv::float16 &&
c.dtype() == tv::float32) {{
scalarType = CUDA_R_32F;
#if CUDART_VERSION >= 11000
computeType = CUBLAS_COMPUTE_32F;
#endif
}} else {{
TV_THROW_RT_ERR("unsupported");
}}
#if CUDART_VERSION >= 11000
check_cublas_status(
cublasLtMatmulDescCreate(&operationDesc, computeType, scalarType));
#else
check_cublas_status(cublasLtMatmulDescCreate(&operationDesc, scalarType));
#endif
cublasOperation_t transa = !transA ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t transb = !transB ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t transc = !transC ? CUBLAS_OP_N : CUBLAS_OP_T;
check_cublas_status(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
check_cublas_status(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
// check_cublas_status(cublasLtMatmulDescSetAttribute(
// operationDesc, CUBLASLT_MATMUL_DESC_TRANSC, &transc,
// sizeof(transc)));
check_cublas_status(cublasLtMatrixLayoutCreate(
&Adesc, tv_dtype_to_blaslt(a.dtype()), transa == CUBLAS_OP_N ? m : k,
transa == CUBLAS_OP_N ? k : m, a.stride(0)));
check_cublas_status(cublasLtMatrixLayoutCreate(
&Bdesc, tv_dtype_to_blaslt(b.dtype()), transb == CUBLAS_OP_N ? k : n,
transb == CUBLAS_OP_N ? n : k, b.stride(0)));
// check_cublas_status(cublasLtMatrixLayoutCreate(
// &Cdesc, tv_dtype_to_blaslt(c.dtype()), transc == CUBLAS_OP_N ? m : n,
// transc == CUBLAS_OP_N ? n : m, c.dim(0)));
check_cublas_status(cublasLtMatrixLayoutCreate(
&Cdesc, tv_dtype_to_blaslt(c.dtype()), m, n, c.stride(0)));
cublasLtMatmulHeuristicResult_t heuristicResult = {{}};
cublasLtMatmulPreference_t preference = NULL;
check_cublas_status(cublasLtMatmulPreferenceCreate(&preference));
size_t workspaceSize = 0;
check_cublas_status(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,
sizeof(workspaceSize)));
int returnedResults = 0;
check_cublas_status(cublasLtMatmulAlgoGetHeuristic(
handle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, preference, 1,
&heuristicResult, &returnedResults));
if (returnedResults == 0) {{
check_cublas_status(CUBLAS_STATUS_NOT_SUPPORTED);
}}
int alpha_storage[4];
int beta_storage[4];
if (scalarType == CUDA_R_32F) {{
*(reinterpret_cast<float *>(alpha_storage)) = 1.0f;
*(reinterpret_cast<float *>(beta_storage)) = 0.0f;
}} else if (scalarType == CUDA_R_16F) {{
*(reinterpret_cast<__half *>(alpha_storage)) = __half(1.0f);
*(reinterpret_cast<__half *>(beta_storage)) = __half(0.0f);
}} else {{
TV_THROW_RT_ERR("unsupported");
}}
check_cublas_status(cublasLtMatmul(
handle, operationDesc, alpha_storage, a.raw_data(), Adesc, b.raw_data(),
Bdesc, beta_storage, c.raw_data(), Cdesc, c.raw_data(), Cdesc,
&heuristicResult.algo, nullptr, 0, stream));
if (preference)
check_cublas_status(cublasLtMatmulPreferenceDestroy(preference));
if (Cdesc)
check_cublas_status(cublasLtMatrixLayoutDestroy(Cdesc));
if (Bdesc)
check_cublas_status(cublasLtMatrixLayoutDestroy(Bdesc));
if (Adesc)
check_cublas_status(cublasLtMatrixLayoutDestroy(Adesc));
if (operationDesc)
check_cublas_status(cublasLtMatmulDescDestroy(operationDesc));
return;
""")
return code
@pccm.static_function
def matmul(self):
code = pccm.code()
code.arg("handle", "cublasLtHandle_t")
code.arg("stream", "cudaStream_t")
code.arg("a, b, c", "tv::Tensor")
code.arg("transA, transB", "bool")
code.raw(f"""
return matmul_colmajor(handle, stream, b, a, c, transB, transA);
""")
return code
@pccm.member_function
def indice_conv_init_gemm(self):
code = pccm.code()
code.arg("features_n, filters_n", "std::string")
code.arg("all_weight_is_krsc, is_kc_not_ck", "bool")
code.arg("kv_center, out_channel", "int")
code.arg("stream_int", "std::uintptr_t")
code.raw(f"""
auto features = alloc_.get_tensor_by_name(features_n);
auto filters = alloc_.get_tensor_by_name(filters_n);
TV_ASSERT_RT_ERR(!features.is_cpu(), "only supprt cuda");
auto out_features = alloc_.empty({pccm.literal(AllocKeys.OutFeatures)},
{{features.dim(0), out_channel}}, features.dtype(), features.device());
if (!all_weight_is_krsc){{
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
if (!is_kc_not_ck){{
matmul(handle_, reinterpret_cast<cudaStream_t>(stream_int),
features, filters[kv_center], out_features, false, false);
}}else{{
matmul(handle_, reinterpret_cast<cudaStream_t>(stream_int),
features, filters[kv_center], out_features, false, true);
}}
}}else{{
filters = filters.view(out_channel, -1, filters.dim(-1));
matmul(handle_, reinterpret_cast<cudaStream_t>(stream_int),
features, filters.select(1, kv_center), out_features, false, true);
}}
return out_features;
""")
return code.ret("tv::Tensor")
class GemmTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
......@@ -21,8 +388,8 @@ class GemmTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
@pccm.member_function
def is_valid(self):
code = pccm.code()
code.raw(f"return splitk > 0 && std::get<0>(arch) > 0")
return code
code.raw(f"return splitk > 0 && std::get<0>(arch) > 0;")
return code.ret("bool")
@pccm.pybind.mark
@pccm.constructor
......@@ -61,7 +428,10 @@ class ConvTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
@pccm.constructor
def defaultctor(self):
code = pccm.code()
code.ctor_init("algo_desp", "tv::gemm::ConvAlgoDesp()")
code.ctor_init(
"algo_desp",
f"tv::gemm::ConvAlgoDesp({NDIM_DONT_CARE}, tv::gemm::ConvOpType::kForward)"
)
code.ctor_init("arch", "std::make_tuple(-1, -1)")
code.ctor_init("splitk", "-1")
return code
......@@ -84,124 +454,1738 @@ class ConvTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
@pccm.member_function
def is_valid(self):
code = pccm.code()
code.raw(f"return splitk > 0 && std::get<0>(arch) > 0")
return code
code.raw(f"return splitk > 0 && std::get<0>(arch) > 0;")
return code.ret("bool")
class GemmTunerSimple(pccm.ParameterizedClass):
def __init__(self, gemm_cu: GemmMainUnitTest, conv_cu: ConvMainUnitTest):
def __init__(self, gemm_cu: Optional[GemmMainUnitTest]):
super().__init__()
self.add_dependency(ExternalAllocator, GemmTuneResult,
ConvTuneResult, TensorView)
self.add_dependency(ExternalAllocator, GemmTuneResult, TensorView,
GemmBasicHost, CompileInfo)
if gemm_cu is not None:
self.add_param_class("gemm", gemm_cu, "GemmMain")
self.add_param_class("conv", conv_cu, "ConvMain")
if not CUMM_CPU_ONLY_BUILD:
assert gemm_cu is not None
self.add_include("tensorview/profile/cuda_profiler.h")
self.add_include("tensorview/utility/tuplehash.h")
self.add_include("mutex")
self.add_typedef(
"static_key_t", "std::tuple<bool, bool, bool, int, "
"int, int, int, std::string>")
self.add_typedef("algo_cache_key_t", "std::tuple<int, "
"int, int, int, int>")
self.add_member("desps_", "std::vector<tv::gemm::GemmAlgoDesp>")
self.add_member("nvrtc_progs_", "std::unordered_map<std::string, tv::NVRTCProgram>")
self.add_member("nvrtc_caches_", "std::unordered_map<std::tuple<std::string, int, int, std::uintptr_t>, tv::NVRTCModule>")
self.add_member(
"static_key_to_desps_",
"std::unordered_map<static_key_t, std::vector<tv::gemm::GemmAlgoDesp>>"
)
self.add_member("prebuilt_names_", "std::unordered_set<std::string>")
self.add_member("mutex_", "std::mutex")
self.add_member(
"nk_forward_cache_, nk_dgrad_cache_, mn_cache_",
"std::unordered_map<algo_cache_key_t, GemmTuneResult>")
@pccm.pybind.mark
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("desps", "std::vector<tv::gemm::GemmAlgoDesp>")
code.arg("nvrtc_progs", "std::unordered_map<std::string, std::string>")
code.ctor_init("desps_", "desps")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code
code.raw(f"""
for (auto& v : nvrtc_progs){{
const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());
nvrtc_progs_.insert(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));
for (auto& d : desps){{
static_key_t static_key = std::make_tuple(d.trans_a(), d.trans_b(), d.trans_c(), d.dtype_a, d.dtype_b,
d.dtype_c, int(d.shuffle_type), d.algo);
auto& vec = static_key_to_desps_[static_key];
vec.push_back(d);
}}
for (auto desp : GemmMain::get_all_algo_desp()){{
prebuilt_names_.insert(desp.__repr__());
}}
""")
return code
@pccm.pybind.mark
@pccm.static_function
def get_available_algo_str_from_arch(self):
code = pccm.code()
code.arg("arch", "std::tuple<int, int>")
code.raw(f"""
std::vector<std::string> res;
""")
for i in range(len(_GEMM_MIN_ARCH_TO_ALGO) - 1, -1, -1):
arch_cur, algos = _GEMM_MIN_ARCH_TO_ALGO[i]
code.raw(f"""
auto arch_cur_{i} = std::make_tuple(int({arch_cur[0]}), int({arch_cur[1]}));
""")
with code.if_(f"arch >= arch_cur_{i}"):
for algo in algos:
code.raw(f"""
res.push_back({pccm.literal(algo)});
""")
code.raw(f"return res;")
return code.ret("std::vector<std::string>")
@pccm.pybind.mark
@pccm.member_function
def get_all_available(self):
code = pccm.code()
code.arg("a, b, c", "tv::Tensor")
code.arg("trans_a, trans_b, trans_c", "bool")
code.arg("arch", "std::tuple<int, int>")
code.arg("nvrtc_progs", "std::unordered_map<std::string, std::string>")
code.ctor_init("desps_", "desps")
code.arg("shuffle_type", "int")
code.raw(f"""
for (auto& v : nvrtc_progs){{
const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());
nvrtc_progs_.insert(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));
if (trans_c){{
trans_a = !trans_a;
trans_b = !trans_b;
std::swap(trans_a, trans_b);
std::swap(a, b);
trans_c = false;
}}
auto avail_algos = get_available_algo_str_from_arch(arch);
std::vector<tv::gemm::GemmAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled(arch);
for (auto algo : avail_algos){{
static_key_t static_key = std::make_tuple(trans_a, trans_b, trans_c, int(a.dtype()),
int(b.dtype()), int(c.dtype()), shuffle_type, algo);
if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{
continue;
}}
auto& desps = static_key_to_desps_.at(static_key);
for (auto& desp : desps){{
if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{
continue;
}}
auto lda = a.stride(0);
auto ldb = b.stride(0);
auto ldc = c.stride(0);
if (desp.supported_ldx(lda, ldb, ldc)){{
if (!is_arch_compiled){{
auto desp2 = desp;
desp2.is_nvrtc = true;
finally_algos.push_back(desp2);
}}else{{
finally_algos.push_back(desp);
}}
}}
}}
}}
return finally_algos;
""")
return code
return code.ret("std::vector<tv::gemm::GemmAlgoDesp>",
pyanno="List[cumm.tensorview.gemm.GemmAlgoDesp]")
@pccm.member_function
def extract_mnk(self):
code = pccm.code()
code.arg("a_shape, b_shape", "tv::TensorShape")
code.arg("trans_a, trans_b, trans_c", "bool")
code.arg("arch", "std::tuple<int, int>")
code.arg("shuffle_type", "int")
code.arg("a_inds_shape, b_inds_shape, c_inds_shape", "tv::TensorShape")
code.arg("hint", "int", f"{AlgoHint.NoHint.value}")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code.ret("std::tuple<int, int, int>")
code.raw(f"""
std::vector<int64_t> a_shape_vec(a_shape.begin(), a_shape.end());
std::vector<int64_t> b_shape_vec(b_shape.begin(), b_shape.end());
std::vector<int64_t> a_inds_shape_vec(a_inds_shape.begin(), a_inds_shape.end());
std::vector<int64_t> b_inds_shape_vec(b_inds_shape.begin(), b_inds_shape.end());
std::vector<int64_t> c_inds_shape_vec(c_inds_shape.begin(), c_inds_shape.end());
class ConvGemmOps(pccm.ParameterizedClass):
return GemmMain::extract_mnk(a_shape_vec, b_shape_vec, trans_a,
trans_b, trans_c,
shuffle_type,
a_inds_shape_vec, b_inds_shape_vec,
c_inds_shape_vec);
""")
return code.ret("std::tuple<int, int, int>")
def __init__(self, gemm_cu: GemmMainUnitTest, conv_cu: ConvMainUnitTest):
super().__init__()
self.add_dependency(ExternalAllocator, GemmTuneResult,
ConvTuneResult)
self.add_param_class("gemm", gemm_cu, "GemmMain")
self.add_param_class("conv", conv_cu, "ConvMain")
@pccm.static_function
def extract_mnk_vector(self):
code = pccm.code()
code.arg("a_shape, b_shape", "std::vector<int64_t>")
code.arg("trans_a, trans_b, trans_c", "bool")
code.arg("shuffle_type", "int")
code.arg("a_inds_shape, b_inds_shape, c_inds_shape",
"std::vector<int64_t>")
code.raw(f"""
return GemmMain::extract_mnk(a_shape, b_shape, trans_a,
trans_b, trans_c,
shuffle_type,
a_inds_shape, b_inds_shape,
c_inds_shape);
""")
return code.ret("std::tuple<int, int, int>")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True)
def cached_get_nvrtc_params(self):
code = pccm.code()
code.arg("desp",
"tv::gemm::GemmAlgoDesp",
pyanno="cumm.tensorview.gemm.GemmAlgoDesp")
code.arg("arch", "std::tuple<int, int>")
code.arg("stream_int", "std::uintptr_t")
code.raw(f"""
TV_THROW_RT_ERR("not implemented in c++, must be overrided in python!!!");
""")
return code.ret("tv::gemm::NVRTCParams",
pyanno="cumm.tensorview.gemm.NVRTCParams")
@pccm.pybind.mark
@pccm.static_function
def indice_conv(self):
"""1. this function need to take a out features
that from subm first mm.
2. this function don't support CPU.
"""
@pccm.member_function
def tune_and_cache(self):
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.arg("out_features_after_mm", "tv::Tensor")
code.arg("features, filters, indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor")
code.arg("num_activate_out", "int")
code.arg("inverse", "bool", "false")
code.arg("subm", "bool", "false")
code.arg("algo", "int", f"{ConvAlgo.Native.value}")
code.arg("filter_hwio", "bool", "false")
code.arg("a, b, c", "tv::Tensor")
code.arg("trans_a, trans_b, trans_c", "bool")
code.arg("arch", "std::tuple<int, int>")
code.arg("shuffle_type", "int")
code.arg("a_inds, b_inds, c_inds", "tv::Tensor")
code.arg("hint", "int", f"{AlgoHint.NoHint.value}")
code.arg("alpha", "float", "1.0")
code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("num_run", "int", "5")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
code.raw("return std::make_tuple(GemmTuneResult(), -1.0f);")
return code.ret(
"std::tuple<GemmTuneResult, float>",
pyanno=
"Tuple[spconv.core_cc.csrc.sparse.convops.GemmTuneResult, float]"
)
code.raw(f"""
throw std::runtime_error("this function can only be used with CUDA.")
TV_ASSERT_RT_ERR(num_run > 1, "error");
auto mnk = extract_mnk(a.shape(), b.shape(), trans_a,
trans_b, trans_c,
arch,
shuffle_type,
a_inds.shape(), b_inds.shape(),
c_inds.shape());
auto m = std::get<0>(mnk);
auto n = std::get<1>(mnk);
auto k = std::get<2>(mnk);
auto avail = get_all_available(a, b, c, trans_a,
trans_b, trans_c, arch, shuffle_type);
auto c_ = c.clone_whole_storage();
std::vector<GemmTuneResult> all_profile_res;
std::vector<int> splitk_tests;
std::vector<float> times;
for (auto& desp : avail){{
tv::gemm::GemmParams params;
if (desp.is_nvrtc && prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
params.a = a;
params.b = b;
params.c = c_;
params.a_inds = a_inds;
params.b_inds = b_inds;
params.c_inds = c_inds;
params.algo_desp = desp;
params.alpha = alpha;
params.beta = beta;
params.stream = stream_int;
if (desp.split_k_serial() && (hint & {AlgoHint.BackwardWeight.value})){{
splitk_tests = {{{', '.join(map(str, SPCONV_BWD_SPLITK))}}};
}} else {{
splitk_tests = {{1}};
}}
for (auto spk : splitk_tests){{
float total_time = 0.0;
params.split_k_slices = spk;
for (int j = 0; j < num_run; ++j){{
auto ev_start = tv::CUDAEvent();
auto ev_end = tv::CUDAEvent();
ev_start.record(stream_int);
GemmMain::matmul2(params);
ev_end.record(stream_int);
if (j > 0){{
// skip first run
total_time += tv::CUDAEvent::sync_and_duration(ev_start, ev_end);
}}
}}
total_time /= (num_run - 1);
times.push_back(total_time);
all_profile_res.push_back(GemmTuneResult(desp, arch, spk));
}}
}}
TV_ASSERT_RT_ERR(!all_profile_res.empty(), "can't find suitable algorithm");
auto min_idx = std::min_element(times.begin(), times.end()) - times.begin();
auto min_tune_res = all_profile_res[min_idx];
{{
std::lock_guard<std::mutex> guard(mutex_);
algo_cache_key_t key;
if (hint & {AlgoHint.BackwardWeight.value}){{
key = std::make_tuple(int(a.dtype()), int(b.dtype()), int(c.dtype()), m, n);
mn_cache_[key] = min_tune_res;
}}
else if (hint & {AlgoHint.BackwardInput.value}){{
key = std::make_tuple(int(a.dtype()), int(b.dtype()), int(c.dtype()), n, k);
nk_dgrad_cache_[key] = min_tune_res;
}}
else if (hint & {AlgoHint.Fowrard.value}){{
key = std::make_tuple(int(a.dtype()), int(b.dtype()), int(c.dtype()), n, k);
nk_forward_cache_[key] = min_tune_res;
}}
else{{
TV_THROW_RT_ERR("not implemented");
}}
}}
return std::make_tuple(min_tune_res, times[min_idx]);
""")
return code.ret("tv::Tensor")
return code.ret(
"std::tuple<GemmTuneResult, float>",
pyanno=
"Tuple[spconv.core_cc.csrc.sparse.convops.GemmTuneResult, float]")
@pccm.pybind.mark
@pccm.member_function
def get_tuned_algo(self):
code = pccm.code()
code.arg("a_dtype, b_dtype, c_dtype", "int")
code.arg("a_shape, b_shape, c_shape", "std::vector<int64_t>")
code.arg("trans_a, trans_b, trans_c", "bool")
code.arg("arch", "std::tuple<int, int>")
code.arg("shuffle_type", "int")
code.arg("a_inds_shape, b_inds_shape, c_inds_shape",
"std::vector<int64_t>")
code.arg("hint", "int", f"{AlgoHint.NoHint.value}")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
code.raw("return std::make_tuple(GemmTuneResult(), false);")
return code.ret("std::tuple<GemmTuneResult, bool>")
code.raw(f"""
TV_ASSERT_RT_ERR(!features.is_cpu(), "this function don't support cpu.")
int out_channel;
if (filter_hwio){{
out_channel = filters.dim(-1);
}}else{{
out_channel = filters.dim(-2);
auto mnk = GemmMain::extract_mnk(a_shape, b_shape, trans_a,
trans_b, trans_c,
shuffle_type,
a_inds_shape, b_inds_shape,
c_inds_shape);
auto m = std::get<0>(mnk);
auto n = std::get<1>(mnk);
auto k = std::get<2>(mnk);
GemmTuneResult res;
bool exists = false;
{{
std::lock_guard<std::mutex> guard(mutex_);
algo_cache_key_t key;
if (hint & {AlgoHint.BackwardWeight.value}){{
key = std::make_tuple(int(a_dtype), int(b_dtype), int(c_dtype), m, n);
if (mn_cache_.find(key) != mn_cache_.end()){{
res = mn_cache_.at(key);
exists = true;
}}
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
int kv = filters.dim(0);
int kv_center = kv / 2;
tv::Tensor out_features;
if (kv == 1 && subm){{
return;
}}
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
int maxnhot = 0;
bool all_zero = true;
for (int i = 0; i < kv; ++i){{
if (indice_pair_num_cpu_ptr[i] != 0){{
all_zero = false;
maxnhot = std::max(maxnhot, indice_pair_num_cpu_ptr[i]);
else if (hint & {AlgoHint.BackwardInput.value}){{
key = std::make_tuple(int(a_dtype), int(b_dtype), int(c_dtype), n, k);
if (nk_dgrad_cache_.find(key) != nk_dgrad_cache_.end()){{
res = nk_dgrad_cache_.at(key);
exists = true;
}}
}}
if (subm && all_zero){{
return;
else if (hint & {AlgoHint.Fowrard.value}){{
key = std::make_tuple(int(a_dtype), int(b_dtype), int(c_dtype), n, k);
if (nk_forward_cache_.find(key) != nk_forward_cache_.end()){{
res = nk_forward_cache_.at(key);
exists = true;
}}
bool inited = subm;
auto a = features;
auto c = out_features;
auto pair_in = indice_pairs[int(inverse)];
auto pair_out = indice_pairs[int(!inverse)];
}}
else{{
TV_THROW_RT_ERR("not implemented");
}}
}}
return std::make_tuple(res, exists);
""")
return code.ret("std::tuple<GemmTuneResult, bool>")
@pccm.pybind.mark
@pccm.member_function
def run_with_tuned_result(self):
code = pccm.code()
code.arg("profile_res", "GemmTuneResult")
code.arg("a, b, c", "tv::Tensor")
code.arg("trans_a, trans_b, trans_c", "bool")
code.arg("arch", "std::tuple<int, int>")
code.arg("stream_int", f"std::uintptr_t")
code.arg("shuffle_type", "int")
code.arg("a_inds, b_inds, c_inds", "tv::Tensor")
code.arg("hint", "int", f"{AlgoHint.NoHint.value}")
code.arg("alpha", "float", "1.0")
code.arg("beta", "float", "0.0")
code.arg("workspace", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("timer", "tv::CUDAKernelTimer", "tv::CUDAKernelTimer(false)",
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)")
code.arg("force_nvrtc", f"bool", "false")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code
code.raw(f"""
auto& desp = profile_res.algo_desp;
int split_k_slices = 1;
if (profile_res.splitk > 1){{
split_k_slices = profile_res.splitk;
}}
tv::gemm::GemmParams params;
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
if (desp.is_nvrtc && (desp_is_static || force_nvrtc)){{
params.nvrtc_params = cached_get_nvrtc_params(desp, profile_res.arch, stream_int);
}}
params.a = a;
params.b = b;
params.c = c;
params.a_inds = a_inds;
params.b_inds = b_inds;
params.c_inds = c_inds;
params.algo_desp = desp;
params.split_k_slices = split_k_slices;
params.stream = stream_int;
params.alpha = alpha;
params.beta = beta;
params.workspace = workspace;
GemmMain::matmul2(params);
""")
return code
class ConvTunerSimple(pccm.ParameterizedClass):
def __init__(self, conv_cu: Optional[ConvMainUnitTest] = None):
super().__init__()
self.add_dependency(ExternalAllocator, ConvTuneResult, TensorView,
GemmBasicHost, CompileInfo)
if conv_cu is not None:
self.add_param_class("conv", conv_cu, "ConvMain")
if not CUMM_CPU_ONLY_BUILD:
assert conv_cu is not None
self.add_include("tensorview/profile/cuda_profiler.h")
self.add_include("tensorview/utility/tuplehash.h")
self.add_include("mutex")
self.add_typedef("static_key_t",
("std::tuple<int, int, int, int, int, "
"int, int, int, int, std::string, int>"))
self.add_typedef(
"algo_cache_key_t", "std::tuple<int, int, int, int, "
"int, int, int, int>")
self.add_member("desps_", "std::vector<tv::gemm::ConvAlgoDesp>")
self.add_member(
"static_key_to_desps_",
"std::unordered_map<static_key_t, std::vector<tv::gemm::ConvAlgoDesp>>"
)
self.add_member("prebuilt_names_", "std::unordered_set<std::string>")
self.add_member("mutex_", "std::mutex")
self.add_member(
"kc_forward_cache_, kc_dgrad_cache_, kc_wgrad_cache_",
"std::unordered_map<algo_cache_key_t, ConvTuneResult>")
@pccm.pybind.mark
@pccm.constructor
def ctor(self):
code = pccm.code()
code.arg("desps", "std::vector<tv::gemm::ConvAlgoDesp>")
code.ctor_init("desps_", "desps")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code
code.raw(f"""
for (auto& d : desps){{
static_key_t static_key = std::make_tuple(
int(d.layout_i), int(d.layout_w), int(d.layout_o),
d.interleave_i, d.interleave_w, d.interleave_o, d.dtype_input(),
d.dtype_weight(), d.dtype_output(), d.algo, int(d.op_type));
auto& vec = static_key_to_desps_[static_key];
vec.push_back(d);
}}
for (auto desp : ConvMain::get_all_conv_algo_desp()){{
prebuilt_names_.insert(desp.__repr__());
}}
""")
return code
@pccm.pybind.mark
@pccm.static_function
def get_available_algo_str_from_arch(self):
code = pccm.code()
code.arg("arch", "std::tuple<int, int>")
code.raw(f"""
std::vector<std::string> res;
""")
for i in range(len(_GEMM_MIN_ARCH_TO_ALGO) - 1, -1, -1):
arch_cur, algos = _GEMM_MIN_ARCH_TO_ALGO[i]
code.raw(f"""
auto arch_cur_{i} = std::make_tuple(int({arch_cur[0]}), int({arch_cur[1]}));
""")
with code.if_(f"arch >= arch_cur_{i}"):
for algo in algos:
code.raw(f"""
res.push_back({pccm.literal(algo)});
""")
code.raw(f"return res;")
return code.ret("std::vector<std::string>")
@pccm.pybind.mark
@pccm.member_function
def get_all_available(self):
code = pccm.code()
code.arg("inp, weight, out", "tv::Tensor")
code.arg("layout_i, layout_w, layout_o", "int")
code.arg("interleave_i, interleave_w, interleave_o", "int")
code.arg("arch", "std::tuple<int, int>")
code.arg("op_type", "int")
code.arg("mask_width", "int")
code.arg("auto_fp32_accum", "bool")
code.arg("fp32_accum", "bool")
code.raw(f"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
auto avail_algos = get_available_algo_str_from_arch(arch);
bool is_fp16 = (inp.dtype() == tv::float16 &&
weight.dtype() == tv::float16 && out.dtype() == tv::float16);
bool use_f32_as_accum = false;
int kv = 1;
for (int i = 0; i < weight.ndim() - 2; ++i){{
kv *= weight.dim(i + 1);
}}
if (is_fp16){{
if (auto_fp32_accum){{
if (op_type_cpp == tv::gemm::ConvOpType::kForward)
use_f32_as_accum = weight.dim(-1) * kv > 128 * 27;
else if (op_type_cpp == tv::gemm::ConvOpType::kBackwardInput)
use_f32_as_accum = weight.dim(0) * kv > 128 * 27;
}}else{{
use_f32_as_accum = fp32_accum;
}}
}}
use_f32_as_accum = false;
std::vector<tv::gemm::ConvAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled(arch);
for (auto algo : avail_algos){{
static_key_t static_key = std::make_tuple(
layout_i, layout_w, layout_o,
interleave_i, interleave_w, interleave_o, inp.dtype(),
weight.dtype(), out.dtype(), algo, op_type);
if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{
continue;
}}
auto& desps = static_key_to_desps_.at(static_key);
for (auto& desp : desps){{
if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{
continue;
}}
if (arch >= std::make_tuple(7, 0) && is_fp16){{
// skip simt fp16 kernels if we have tensor core
if (desp.algo == {pccm.literal(GemmAlgo.Simt.value)}){{
continue;
}}
if (use_f32_as_accum){{
if (desp.dacc == tv::float16){{
continue;
}}
}}
}}
int ldi = inp.dim(-1);
int ldw = weight.dim(-1);
int ldo = out.dim(-1);
bool mask_width_valid = true;
if (desp.op_type == tv::gemm::ConvOpType::kBackwardWeight){{
TV_ASSERT_RT_ERR(mask_width > 0, "eroro");
mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}}
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!is_arch_compiled){{
auto desp2 = desp;
desp2.is_nvrtc = true;
finally_algos.push_back(desp2);
}}else{{
finally_algos.push_back(desp);
}}
}}
}}
}}
return finally_algos;
""")
return code.ret("std::vector<tv::gemm::ConvAlgoDesp>",
pyanno="List[cumm.tensorview.gemm.ConvAlgoDesp]")
@pccm.pybind.mark(virtual=True)
@pccm.member_function(virtual=True)
def cached_get_nvrtc_params(self):
code = pccm.code()
code.arg("desp",
"tv::gemm::ConvAlgoDesp",
pyanno="cumm.tensorview.gemm.ConvAlgoDesp")
code.arg("arch", "std::tuple<int, int>")
code.arg("stream_int", "std::uintptr_t")
code.raw(f"""
TV_THROW_RT_ERR("not implemented in c++, must be overrided in python!!!");
""")
return code.ret("tv::gemm::NVRTCParams",
pyanno="cumm.tensorview.gemm.NVRTCParams")
@pccm.pybind.mark
@pccm.member_function
def tune_and_cache(self):
code = pccm.code()
code.arg("op_type", "int")
code.arg("inp, weight, output", "tv::Tensor")
code.arg("layout_i, layout_w, layout_o", "int")
code.arg("interleave_i, interleave_w, interleave_o", "int")
code.arg("arch", "std::tuple<int, int>")
code.arg("mask, mask_argsort, indices", "tv::Tensor")
code.arg("reverse_mask", "bool")
code.arg("mask_filter", "uint32_t", "0xffffffff")
code.arg("mask_width", "int", "-1")
code.arg("mask_output", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("alpha", "float", "1.0")
code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false")
code.arg("num_run", "int", "5")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code.ret(
"std::tuple<ConvTuneResult, float>",
pyanno=
"Tuple[spconv.core_cc.csrc.sparse.convops.ConvTuneResult, float]"
)
code.raw(f"""
TV_ASSERT_RT_ERR(num_run > 1, "error");
auto avail = get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, interleave_i, interleave_w, interleave_o,
arch, op_type, mask_width,
auto_fp32_accum, fp32_accum);
inp = inp.clone();
weight = weight.clone();
output = output.clone();
int channel_k = output.dim(1);
int channel_c = inp.dim(1);
std::vector<ConvTuneResult> all_profile_res;
std::vector<int> splitk_tests;
std::vector<float> times;
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
for (auto& desp : avail){{
tv::gemm::ConvParams params({NDIM_DONT_CARE}, op_type_cpp, tv::CUDAKernelTimer(false));
if (desp.is_nvrtc && prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
params.conv_algo_desp = desp;
params.input = inp;
params.weight = weight.view(channel_k, -1, channel_c);
params.output = output;
params.mask_width = mask_width;
params.alpha = alpha;
params.beta = beta;
params.stream = stream_int;
params.mask_argsort = mask_argsort;
params.indices = indices;
params.mask = mask;
params.mask_output = mask_output;
// if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
// TV_ASSERT_RT_ERR(!mask_output.empty(), "error");
// }}
if (op_type_cpp == tv::gemm::ConvOpType::kBackwardInput){{
params.reverse_mask = reverse_mask;
}}
params.mask_filter = mask_filter;
if (desp.split_k_serial() && (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight)){{
splitk_tests = {{{', '.join(map(str, SPCONV_BWD_SPLITK))}}};
}} else {{
splitk_tests = {{1}};
}}
for (auto spk : splitk_tests){{
float total_time = 0.0;
params.split_k_slices = spk;
for (int j = 0; j < num_run; ++j){{
auto ev_start = tv::CUDAEvent();
auto ev_end = tv::CUDAEvent();
ev_start.record(stream_int);
ConvMain::implicit_gemm2(params);
ev_end.record(stream_int);
if (j > 0){{
// skip first run
total_time += tv::CUDAEvent::sync_and_duration(ev_start, ev_end);
}}
}}
total_time /= (num_run - 1);
times.push_back(total_time);
all_profile_res.push_back(ConvTuneResult(desp, arch, spk));
}}
}}
TV_ASSERT_RT_ERR(!all_profile_res.empty(), "can't find suitable algorithm for", op_type);
auto min_idx = std::min_element(times.begin(), times.end()) - times.begin();
auto min_tune_res = all_profile_res[min_idx];
if (op_type_cpp != tv::gemm::ConvOpType::kBackwardWeight){{
mask_width = -1;
}}
algo_cache_key_t key;
key = std::make_tuple(int(inp.dtype()), int(weight.dtype()),
int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width);
{{
std::lock_guard<std::mutex> guard(mutex_);
if (op_type_cpp == tv::gemm::ConvOpType::kForward){{
kc_forward_cache_[key] = min_tune_res;
}}
else if (op_type_cpp == tv::gemm::ConvOpType::kBackwardInput){{
kc_dgrad_cache_[key] = min_tune_res;
}}
else if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
kc_wgrad_cache_[key] = min_tune_res;
}}
else{{
TV_THROW_RT_ERR("not implemented");
}}
}}
return std::make_tuple(min_tune_res, times[min_idx]);
""")
return code.ret(
"std::tuple<ConvTuneResult, float>",
pyanno=
"Tuple[spconv.core_cc.csrc.sparse.convops.ConvTuneResult, float]")
@pccm.pybind.mark
@pccm.member_function
def get_tuned_algo(self):
code = pccm.code()
code.arg("op_type", "int")
code.arg("i_dtype, w_dtype, o_dtype", "int")
code.arg("k, c", "int")
code.arg("arch", "std::tuple<int, int>")
code.arg("mask_width", "int", "-1")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code.ret("std::tuple<ConvTuneResult, bool>")
code.raw(f"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
if (op_type_cpp != tv::gemm::ConvOpType::kBackwardWeight){{
mask_width = -1;
}}
algo_cache_key_t key;
key = std::make_tuple(i_dtype, w_dtype, o_dtype, k, c,
std::get<0>(arch), std::get<1>(arch), mask_width);
ConvTuneResult res;
bool exists = false;
{{
std::lock_guard<std::mutex> guard(mutex_);
if (op_type_cpp == tv::gemm::ConvOpType::kForward){{
if (kc_forward_cache_.find(key) != kc_forward_cache_.end()){{
res = kc_forward_cache_.at(key);
exists = true;
}}
}}
else if (op_type_cpp == tv::gemm::ConvOpType::kBackwardInput){{
if (kc_dgrad_cache_.find(key) != kc_dgrad_cache_.end()){{
res = kc_dgrad_cache_.at(key);
exists = true;
}}
}}
else if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
if (kc_wgrad_cache_.find(key) != kc_wgrad_cache_.end()){{
res = kc_wgrad_cache_.at(key);
exists = true;
}}
}}
else{{
TV_THROW_RT_ERR("not implemented");
}}
}}
return std::make_tuple(res, exists);
""")
return code.ret("std::tuple<ConvTuneResult, bool>")
@pccm.pybind.mark
@pccm.member_function
def run_with_tuned_result(self):
code = pccm.code()
code.arg("profile_res", "ConvTuneResult")
code.arg("op_type", "int")
code.arg("inp, weight, output", "tv::Tensor")
code.arg("mask, mask_argsort, mask_output, indices", "tv::Tensor")
code.arg("reverse_mask", "bool")
code.arg("mask_filter", "uint32_t", "0xffffffff")
code.arg("mask_width", "int", "-1")
code.arg("alpha", "float", "1.0")
code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0")
code.arg("workspace", "tv::Tensor", "tv::Tensor()",
"cumm.tensorview.Tensor = Tensor()")
code.arg("verbose", f"bool", "false")
code.arg("timer", "tv::CUDAKernelTimer", "tv::CUDAKernelTimer(false)",
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(false)")
code.arg("force_nvrtc", f"bool", "false")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code
code.raw(f"""
auto desp = profile_res.algo_desp;
if (force_nvrtc){{
desp.is_nvrtc = true;
}}
int split_k_slices = 1;
if (profile_res.splitk > 1){{
split_k_slices = profile_res.splitk;
}}
int channel_k = output.dim(1);
int channel_c = inp.dim(1);
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
auto arch = profile_res.arch;
tv::gemm::ConvParams params({NDIM_DONT_CARE}, op_type_cpp, timer);
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
if (desp.is_nvrtc && (desp_is_static || force_nvrtc)){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
params.conv_algo_desp = desp;
params.input = inp;
params.weight = weight.view(channel_k, -1, channel_c);
params.output = output;
params.verbose = verbose;
params.split_k_slices = split_k_slices;
params.alpha = alpha;
params.beta = beta;
params.stream = stream_int;
params.mask_argsort = mask_argsort;
params.indices = indices;
params.mask = mask;
params.mask_filter = mask_filter;
params.mask_width = mask_width;
params.mask_output = mask_output;
params.reverse_mask = reverse_mask;
if (timer.enable()){{
params.timer = timer;
}}
params.workspace = workspace;
ConvMain::implicit_gemm2(params);
""")
return code
@pccm.pybind.mark
@pccm.member_function
def query_workspace_size(self):
code = pccm.code()
code.arg("desp", "tv::gemm::ConvAlgoDesp")
code.arg("splitk", "int")
code.arg("op_type, N, C, K, kv", "int")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code.ret("int")
code.raw(f'''
auto mnk = ConvMain::extract_mnk(op_type, N, C, K, kv, -1, -1, true);
return desp.query_conv_workspace_size(
std::get<0>(mnk), std::get<1>(mnk), std::get<2>(mnk),
splitk, kv);
''')
return code.ret("int")
class ConvGemmOps(pccm.ParameterizedClass):
def __init__(self, gemm_tuner: GemmTunerSimple,
conv_tuner: ConvTunerSimple):
super().__init__()
self.add_dependency(
ExternalAllocator,
GemmTuneResult,
ConvTuneResult,
ExternalSpconvMatmul,
)
self.add_param_class("gemm", gemm_tuner, "GemmTuner")
self.add_param_class("conv", conv_tuner, "ConvTuner")
@pccm.pybind.mark
@pccm.static_function
def get_compute_capability(self):
code = pccm.code()
code.arg("index", "int", "-1")
code.raw(f"""
if (index == -1){{
checkCudaErrors(cudaGetDevice(&index));
}}
#ifdef TV_CUDA
cudaDeviceProp prop;
checkCudaErrors(cudaGetDeviceProperties(&prop, index));
return std::make_tuple(prop.major, prop.minor);
#else
return std::make_tuple(-1, -1);
#endif
""")
return code.ret("std::tuple<int, int>")
@pccm.pybind.mark
@pccm.static_function
def indice_conv(self):
"""1. this function need to take a out features
that from subm first mm.
2. this function don't support CPU.
"""
code = pccm.code()
code.add_dependency(GatherCPU)
code.arg("allocator", "ExternalAllocator&")
code.arg("ext_mm", "ExternalSpconvMatmul&")
code.arg("gemm_tuner", "GemmTuner&")
code.arg("all_w_is_krsc, filter_hwio", "bool")
code.arg("features, filters, indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor")
code.arg("num_activate_out", "int")
code.arg("inverse", "bool", "false")
code.arg("subm", "bool", "false")
code.arg("algo", "int", f"{ConvAlgo.Native.value}")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.raw(f"""
int kv_dim, out_channel, kv;
std::vector<int64_t> filter_shape_per_kv;
bool is_KC_not_CK;
if (!all_w_is_krsc){{
kv_dim = 0;
is_KC_not_CK = !filter_hwio;
if (filter_hwio){{
out_channel = filters.dim(-1);
filter_shape_per_kv = {{filters.dim(-2), out_channel}};
}}else{{
out_channel = filters.dim(-2);
filter_shape_per_kv = {{out_channel, filters.dim(-1)}};
}}
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
kv = filters.dim(0);
}}else{{
kv_dim = 1;
out_channel = filters.dim(0);
filters = filters.view(out_channel, -1, filters.dim(-1));
is_KC_not_CK = true;
kv = filters.dim(1);
filter_shape_per_kv = {{out_channel, filters.dim(-1)}};
}}
int kv_center = kv / 2;
tv::Tensor out_features;
if (subm){{
out_features = ext_mm.indice_conv_init_gemm({pccm.literal(AllocKeys.Features)},
{pccm.literal(AllocKeys.Filters)}, all_w_is_krsc,
is_KC_not_CK, kv_center, out_channel);
}}else{{
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device());
}}
if (kv == 1 && subm){{
return;
}}
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
int maxnhot = 0;
bool all_zero = true;
for (int i = 0; i < kv; ++i){{
if (indice_pair_num_cpu_ptr[i] != 0){{
indice_pair_num_cpu_ptr[i] = std::min(indice_pair_num_cpu_ptr[i], int(indice_pairs.dim(2)));
all_zero = false;
maxnhot = std::max(maxnhot, indice_pair_num_cpu_ptr[i]);
}}
}}
if (subm && all_zero){{
return;
}}
bool inited = subm;
auto a = features;
auto c = out_features;
auto pair_in = indice_pairs[int(inverse)];
auto pair_out = indice_pairs[int(!inverse)];
if (features.is_cpu()){{
TV_ASSERT_RT_ERR(filters.is_cpu() && indice_pairs.is_cpu(), "error");
auto inp_buffer = allocator.empty({pccm.literal(AllocKeys.InpBuffer)},
{{maxnhot, features.dim(1)}}, features.dtype(), -1);
auto out_buffer = allocator.empty({pccm.literal(AllocKeys.OutBuffer)},
{{maxnhot, out_features.dim(1)}}, out_features.dtype(), -1);
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (subm && i == kv_center){{
continue;
}}
if (subm && i > kv_center){{
nhot = indice_pair_num_cpu_ptr[kv - i - 1];
}}
if (nhot <= 0){{
continue;
}}
auto inp_indices = pair_in[i].slice_first_axis(0, nhot);
auto out_indices = pair_out[i].slice_first_axis(0, nhot);
GatherCPU::gather(inp_buffer, a, inp_indices);
ext_mm.indice_conv_cpu_gemm({pccm.literal(AllocKeys.InpBuffer)},
{pccm.literal(AllocKeys.OutBuffer)},
{pccm.literal(AllocKeys.Filters)}, all_w_is_krsc,
is_KC_not_CK, nhot, i);
GatherCPU::scatter_add(c, out_buffer, out_indices);
}}
return;
}}
""")
if CUMM_CPU_ONLY_BUILD:
return code
code.raw(f"""
int profile_idx = kv_center;
if (subm)
profile_idx = kv_center - 1;
int nhot_profile = indice_pair_num_cpu_ptr[profile_idx];
if (nhot_profile == 0){{
profile_idx = 0;
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (nhot > nhot_profile){{
nhot_profile = nhot;
profile_idx = i;
}}
}}
}}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability();
auto a_shape = a.shape();
auto c_shape = c.shape();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
auto tuned_res_exist = gemm_tuner.get_tuned_algo(
int(a.dtype()),
int(filters.dtype()),
int(c.dtype()),
std::vector<int64_t>(a_shape.begin(), a_shape.end()),
filter_shape_per_kv,
std::vector<int64_t>(c_shape.begin(), c_shape.end()),
false,
is_KC_not_CK,
false,
arch,
sac_shuffle_type,
{{nhot_profile}},
{{}},
{{nhot_profile}},
{AlgoHint.Fowrard.value});
auto tune_res = std::get<0>(tuned_res_exist);
auto exists = std::get<1>(tuned_res_exist);
if (!exists){{
auto inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile);
auto out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile);
auto filter = filters.select(kv_dim, profile_idx);
auto tune_res_time = gemm_tuner.tune_and_cache(
a,
filter,
c,
false,
is_KC_not_CK,
false,
arch,
sac_shuffle_type,
inp_indices,
tv::Tensor(),
out_indices,
{AlgoHint.Fowrard.value},
1.0,
0.0,
stream_int);
tune_res = std::get<0>(tune_res_time);
}}
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (subm && i == kv_center){{
continue;
}}
if (subm && i > kv_center){{
nhot = indice_pair_num_cpu_ptr[kv - i - 1];
}}
if (nhot <= 0){{
continue;
}}
auto inp_indices = pair_in[i].slice_first_axis(0, nhot);
auto out_indices = pair_out[i].slice_first_axis(0, nhot);
auto b = filters.select(kv_dim, i);
float beta = inited ? 1.0 : 0.0;
gemm_tuner.run_with_tuned_result(
tune_res,
a,
b,
c,
false,
is_KC_not_CK,
false,
arch,
stream_int,
sac_shuffle_type,
inp_indices,
tv::Tensor(),
out_indices,
{AlgoHint.Fowrard.value},
1.0,
beta);
inited = true;
}}
""")
return code
@pccm.pybind.mark
@pccm.static_function
def indice_conv_backward(self):
code = pccm.code()
code.add_dependency(GatherCPU)
code.arg("allocator", "ExternalAllocator&")
code.arg("ext_mm", "ExternalSpconvMatmul&")
code.arg("gemm_tuner", "GemmTuner&")
code.arg("all_w_is_krsc, filter_hwio", "bool")
code.arg("features, filters, out_bp, indice_pairs", "tv::Tensor")
code.arg("indice_pair_num", "tv::Tensor")
code.arg("inverse", "bool", "false")
code.arg("subm", "bool", "false")
code.arg("algo", "int", f"{ConvAlgo.Native.value}")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.raw(f"""
int kv_dim, out_channel, kv;
std::vector<int64_t> filter_shape_per_kv;
auto prev_filter_shape_vec = filters.shape_vector();
bool is_KC_not_CK;
if (!all_w_is_krsc){{
kv_dim = 0;
is_KC_not_CK = !filter_hwio;
if (filter_hwio){{
out_channel = filters.dim(-1);
filter_shape_per_kv = {{filters.dim(-2), out_channel}};
}}else{{
out_channel = filters.dim(-2);
filter_shape_per_kv = {{out_channel, filters.dim(-1)}};
}}
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
kv = filters.dim(0);
}}else{{
kv_dim = 1;
out_channel = filters.dim(0);
filters = filters.view(out_channel, -1, filters.dim(-1));
is_KC_not_CK = true;
kv = filters.dim(1);
filter_shape_per_kv = {{out_channel, filters.dim(-1)}};
}}
int kv_center = kv / 2;
tv::Tensor din;
auto dfilters = allocator.zeros({pccm.literal(AllocKeys.DFilters)},
prev_filter_shape_vec, features.dtype(), features.device());
dfilters = dfilters.view(filters.shape());
if (subm){{
din = ext_mm.indice_conv_bwd_init_gemm({pccm.literal(AllocKeys.Features)},
{pccm.literal(AllocKeys.Filters)}, {pccm.literal(AllocKeys.OutBp)},
{pccm.literal(AllocKeys.DFilters)},
all_w_is_krsc,
is_KC_not_CK, kv_center);
}}else{{
din = allocator.zeros({pccm.literal(AllocKeys.DIn)},
features.shape_vector(), features.dtype(), features.device());
}}
if (kv == 1 && subm){{
return;
}}
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
int maxnhot = 0;
bool all_zero = true;
for (int i = 0; i < kv; ++i){{
if (indice_pair_num_cpu_ptr[i] != 0){{
indice_pair_num_cpu_ptr[i] = std::min(indice_pair_num_cpu_ptr[i], int(indice_pairs.dim(2)));
all_zero = false;
maxnhot = std::max(maxnhot, indice_pair_num_cpu_ptr[i]);
}}
}}
if (subm && all_zero){{
return;
}}
bool inited = subm;
auto pair_in = indice_pairs[int(inverse)];
auto pair_out = indice_pairs[int(!inverse)];
if (features.is_cpu()){{
TV_ASSERT_RT_ERR(filters.is_cpu() && indice_pairs.is_cpu(), "error");
auto inp_buffer = allocator.empty({pccm.literal(AllocKeys.InpBuffer)},
{{maxnhot, features.dim(1)}}, features.dtype(), -1);
auto out_buffer = allocator.empty({pccm.literal(AllocKeys.OutBuffer)},
{{maxnhot, out_bp.dim(1)}}, out_bp.dtype(), -1);
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (subm && i == kv_center){{
continue;
}}
if (subm && i > kv_center){{
nhot = indice_pair_num_cpu_ptr[kv - i - 1];
}}
if (nhot <= 0){{
continue;
}}
auto inp_indices = pair_in[i].slice_first_axis(0, nhot);
auto out_indices = pair_out[i].slice_first_axis(0, nhot);
GatherCPU::gather(inp_buffer, features, inp_indices);
GatherCPU::gather(out_buffer, out_bp, out_indices);
ext_mm.indice_conv_bwd_cpu_gemm({pccm.literal(AllocKeys.InpBuffer)},
{pccm.literal(AllocKeys.OutBuffer)},
{pccm.literal(AllocKeys.Filters)},
{pccm.literal(AllocKeys.DFilters)}, all_w_is_krsc,
is_KC_not_CK, nhot, i);
GatherCPU::scatter_add(din, inp_buffer, inp_indices);
}}
return;
}}
""")
if CUMM_CPU_ONLY_BUILD:
return code
code.raw(f"""
int profile_idx = kv_center;
if (subm)
profile_idx = kv_center - 1;
int nhot_profile = indice_pair_num_cpu_ptr[profile_idx];
if (nhot_profile == 0){{
profile_idx = 0;
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (nhot > nhot_profile){{
nhot_profile = nhot;
profile_idx = i;
}}
}}
}}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
int sab_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAB);
auto dgrad_tuned_res_exist = gemm_tuner.get_tuned_algo(
int(out_bp.dtype()),
int(filters.dtype()),
int(din.dtype()),
out_bp.shape_vector(),
filter_shape_per_kv,
din.shape_vector(),
false,
!is_KC_not_CK,
false,
arch,
sac_shuffle_type,
{{nhot_profile}},
{{}},
{{nhot_profile}},
{AlgoHint.BackwardInput.value});
auto tuned_res_dgrad = std::get<0>(dgrad_tuned_res_exist);
auto dgrad_exists = std::get<1>(dgrad_tuned_res_exist);
if (!dgrad_exists){{
auto inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile);
auto out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile);
auto filter = filters.select(kv_dim, profile_idx);
auto tune_res_time = gemm_tuner.tune_and_cache(
out_bp,
filter,
din,
false,
!is_KC_not_CK,
false,
arch,
sac_shuffle_type,
out_indices,
tv::Tensor(),
inp_indices,
{AlgoHint.BackwardInput.value},
1.0,
0.0,
stream_int);
tuned_res_dgrad = std::get<0>(tune_res_time);
}}
tv::Tensor a_wgrad, b_wgrad;
if (is_KC_not_CK){{
a_wgrad = out_bp;
b_wgrad = features;
}}
else{{
a_wgrad = features;
b_wgrad = out_bp;
}}
auto wgrad_tuned_res_exist = gemm_tuner.get_tuned_algo(
int(a_wgrad.dtype()),
int(b_wgrad.dtype()),
int(filters.dtype()),
a_wgrad.shape_vector(),
b_wgrad.shape_vector(),
filter_shape_per_kv,
true,
false,
false,
arch,
sab_shuffle_type,
{{nhot_profile}},
{{nhot_profile}},
{{}},
{AlgoHint.BackwardWeight.value});
auto tuned_res_wgrad = std::get<0>(wgrad_tuned_res_exist);
auto wgrad_exists = std::get<1>(wgrad_tuned_res_exist);
if (!wgrad_exists){{
auto inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile);
auto out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile);
auto dfilter = dfilters.select(kv_dim, profile_idx);
tv::Tensor a_inds_wgrad, b_inds_wgrad;
if (is_KC_not_CK){{
a_inds_wgrad = out_indices;
b_inds_wgrad = inp_indices;
}}else{{
a_inds_wgrad = inp_indices;
b_inds_wgrad = out_indices;
}}
auto tune_res_time = gemm_tuner.tune_and_cache(
a_wgrad,
b_wgrad,
dfilter,
true,
false,
false,
arch,
sab_shuffle_type,
a_inds_wgrad,
b_inds_wgrad,
tv::Tensor(),
{AlgoHint.BackwardWeight.value},
1.0,
0.0,
stream_int);
tuned_res_wgrad = std::get<0>(tune_res_time);
}}
std::vector<int64_t> a_shape{{maxnhot, out_bp.dim(1)}};
std::vector<int64_t> b_shape{{maxnhot, features.dim(1)}};
if (!is_KC_not_CK){{
std::swap(a_shape, b_shape);
}}
auto mnk = GemmTuner::extract_mnk_vector(a_shape, b_shape,
tuned_res_wgrad.algo_desp.trans_a(),
tuned_res_wgrad.algo_desp.trans_b(),
tuned_res_wgrad.algo_desp.trans_c(),
sab_shuffle_type,
{{maxnhot}}, {{maxnhot}}, {{}});
auto ws_size = tuned_res_wgrad.algo_desp.query_workspace_size(
std::get<0>(mnk), std::get<1>(mnk), std::get<2>(mnk), tuned_res_wgrad.splitk);
ExternalAllocator::guard_t workspace_guard;
tv::Tensor workspace;
if (ws_size > 0){{
workspace_guard = allocator.empty_guard({{int64_t(ws_size)}}, tv::uint8, 0);
workspace = workspace_guard->tensor;
}}
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (subm && i == kv_center){{
continue;
}}
if (subm && i > kv_center){{
nhot = indice_pair_num_cpu_ptr[kv - i - 1];
}}
if (nhot <= 0){{
continue;
}}
auto inp_indices = pair_in[i].slice_first_axis(0, nhot);
auto out_indices = pair_out[i].slice_first_axis(0, nhot);
auto filter_i = filters.select(kv_dim, i);
float beta = inited ? 1.0 : 0.0;
gemm_tuner.run_with_tuned_result(
tuned_res_dgrad,
out_bp,
filter_i,
din,
false,
!is_KC_not_CK,
false,
arch,
stream_int,
sac_shuffle_type,
out_indices,
tv::Tensor(),
inp_indices,
{AlgoHint.BackwardInput.value},
1.0,
beta);
tv::Tensor a = out_bp;
tv::Tensor b = features;
tv::Tensor a_inds = out_indices;
tv::Tensor b_inds = inp_indices;
if (!is_KC_not_CK){{
std::swap(a, b);
std::swap(a_inds, b_inds);
}}
gemm_tuner.run_with_tuned_result(
tuned_res_wgrad,
a,
b,
dfilters.select(kv_dim, i),
true,
false,
false,
arch,
stream_int,
sab_shuffle_type,
a_inds,
b_inds,
tv::Tensor(),
{AlgoHint.BackwardWeight.value},
1.0,
beta);
inited = true;
}}
""")
return code
@pccm.pybind.mark
@pccm.static_function
def implicit_gemm(self):
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.arg("conv_tuner", "ConvTuner&")
code.arg("features, filters, pair_fwd", "tv::Tensor")
code.arg("pair_mask_fwd_splits, mask_argsort_fwd_splits",
"std::vector<tv::Tensor>")
code.arg("num_activate_out", "int")
code.arg("masks", "tv::Tensor")
code.arg("is_train, is_subm", "bool", "false")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("timer", "tv::CUDAKernelTimer", "tv::CUDAKernelTimer(false)",
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)")
code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code.ret("int")
code.raw(f"""
uint32_t* mask_ptr = masks.data_ptr<uint32_t>();
int num_mask = masks.dim(0);
int out_channel = filters.dim(0);
int in_channel = filters.dim(-1);
int num_split = pair_mask_fwd_splits.size();
TV_ASSERT_RT_ERR(num_mask == num_split, "error");
filters = filters.view(out_channel, -1, in_channel);
tv::Tensor out_features;
if (is_subm){{
out_features = allocator.empty({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device());
}}else{{
out_features = allocator.zeros({pccm.literal(AllocKeys.OutFeatures)},
{{num_activate_out, out_channel}}, features.dtype(), features.device());
}}
auto arch = get_compute_capability();
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto tuned_res_exist = conv_tuner.get_tuned_algo(
kForwardInt,
int(features.dtype()),
int(filters.dtype()),
int(out_features.dtype()),
out_channel, in_channel, arch);
auto tune_res = std::get<0>(tuned_res_exist);
auto exists = std::get<1>(tuned_res_exist);
if (!exists){{
auto tune_res_time = conv_tuner.tune_and_cache(
kForwardInt,
features, filters, out_features,
kChannelLastInt,
kChannelLastInt,
kChannelLastInt,
1, 1, 1,
arch,
pair_mask_fwd_splits[0].type_view(tv::uint32),
mask_argsort_fwd_splits[0],
pair_fwd,
false, // reverse_mask
mask_ptr[0], // mask_filter
-1,
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
auto_fp32_accum,
fp32_accum);
tune_res = std::get<0>(tune_res_time);
}}
int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits;
if (is_train){{
mask_output_fwd = allocator.empty({pccm.literal(AllocKeys.MaskOutputFwd)},
{{num_split, tv::div_up(num_activate_out, mask_width)}},
tv::uint32, features.device());
for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(mask_output_fwd[i]);
}}
}}else{{
for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(tv::Tensor());
}}
}}
for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1;
conv_tuner.run_with_tuned_result(
tune_res,
kForwardInt,
features,
filters,
out_features,
pair_mask_fwd_splits[j].type_view(tv::uint32),
mask_argsort_fwd_splits[j],
mask_output_fwd_splits[j],
pair_fwd,
false, // reverse_mask
mask_ptr[j],
-1, // mask_width
1.0, beta,
stream_int,
tv::Tensor(), // workspace
false, // verbose
timer);
}}
return mask_width;
""")
return code.ret("int")
@pccm.pybind.mark
@pccm.static_function
def implicit_gemm_backward(self):
code = pccm.code()
code.arg("allocator", "ExternalAllocator&")
code.arg("conv_tuner", "ConvTuner&")
code.arg("features, filters, out_bp, pair_fwd, pair_bwd", "tv::Tensor")
code.arg("pair_mask_fwd_splits, pair_mask_bwd_splits",
"std::vector<tv::Tensor>")
code.arg("mask_argsort_fwd_splits, mask_argsort_bwd_splits",
"std::vector<tv::Tensor>")
code.arg("mask_output_fwd", "tv::Tensor")
code.arg("masks", "tv::Tensor")
code.arg("mask_width", "int")
code.arg("is_subm", "bool")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("timer", "tv::CUDAKernelTimer", "tv::CUDAKernelTimer(false)",
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)")
code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code
code.raw(f"""
auto filters_shape = filters.shape();
auto filters_shape_vec = filters.shape_vector();
uint32_t* mask_ptr = masks.data_ptr<uint32_t>();
int num_mask = masks.dim(0);
int out_channel = filters.dim(0);
int in_channel = filters.dim(-1);
int num_split = pair_mask_fwd_splits.size();
TV_ASSERT_RT_ERR(num_mask == num_split, "error");
filters = filters.view(out_channel, -1, in_channel);
int kv = filters.dim(1);
tv::Tensor din;
if (is_subm){{
din = allocator.empty({pccm.literal(AllocKeys.DIn)},
features.shape_vector(), features.dtype(), features.device());
}}else{{
din = allocator.zeros({pccm.literal(AllocKeys.DIn)},
features.shape_vector(), features.dtype(), features.device());
}}
tv::Tensor dfilters = allocator.zeros({pccm.literal(AllocKeys.DFilters)},
filters_shape_vec, filters.dtype(), filters.device());
dfilters = dfilters.view(out_channel, -1, in_channel);
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
constexpr auto kBackwardInputInt = static_cast<int>(tv::gemm::ConvOpType::kBackwardInput);
constexpr auto kBackwardWeightInt = static_cast<int>(tv::gemm::ConvOpType::kBackwardWeight);
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto arch = get_compute_capability();
auto dgrad_tuned_res_exist = conv_tuner.get_tuned_algo(
kBackwardInputInt,
int(din.dtype()),
int(filters.dtype()),
int(out_bp.dtype()),
out_channel, in_channel, arch);
auto wgrad_tuned_res_exist = conv_tuner.get_tuned_algo(
kBackwardWeightInt,
int(features.dtype()),
int(dfilters.dtype()),
int(out_bp.dtype()),
out_channel, in_channel, arch, mask_width);
auto dgrad_tune_res = std::get<0>(dgrad_tuned_res_exist);
auto dgrad_exists = std::get<1>(dgrad_tuned_res_exist);
auto wgrad_tune_res = std::get<0>(wgrad_tuned_res_exist);
auto wgrad_exists = std::get<1>(wgrad_tuned_res_exist);
if (!dgrad_exists){{
tv::Tensor mask, mask_argsort;
if (is_subm){{
mask = pair_mask_fwd_splits[0].type_view(tv::uint32);
mask_argsort = mask_argsort_fwd_splits[0];
}}else{{
mask = pair_mask_bwd_splits[0].type_view(tv::uint32);
mask_argsort = mask_argsort_bwd_splits[0];
}}
auto tune_res_time = conv_tuner.tune_and_cache(
kBackwardInputInt,
din, filters, out_bp,
kChannelLastInt,
kChannelLastInt,
kChannelLastInt,
1, 1, 1,
arch,
mask,
mask_argsort,
pair_bwd,
is_subm, // reverse_mask
mask_ptr[0], // mask_filter
-1, // mask width
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
auto_fp32_accum,
fp32_accum);
dgrad_tune_res = std::get<0>(tune_res_time);
}}
if (!wgrad_exists){{
auto tune_res_time = conv_tuner.tune_and_cache(
kBackwardWeightInt,
features, dfilters, out_bp,
kChannelLastInt,
kChannelLastInt,
kChannelLastInt,
1, 1, 1,
arch,
mask_output_fwd[0].type_view(tv::uint32),
mask_argsort_fwd_splits[0],
pair_fwd,
false, // reverse_mask
mask_ptr[0], // mask_filter
mask_width,
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
auto_fp32_accum,
fp32_accum);
wgrad_tune_res = std::get<0>(tune_res_time);
}}
int ws_size = conv_tuner.query_workspace_size(wgrad_tune_res.algo_desp,
wgrad_tune_res.splitk,
kBackwardWeightInt,
pair_fwd.dim(1), in_channel,
out_channel, kv);
ExternalAllocator::guard_t workspace_guard;
tv::Tensor workspace;
if (ws_size > 0){{
workspace_guard = allocator.empty_guard({{int64_t(ws_size)}}, tv::uint8, 0);
workspace = workspace_guard->tensor;
}}
for (int j = 0; j < num_split; ++j){{
tv::Tensor mask, mask_argsort;
if (is_subm){{
mask = pair_mask_fwd_splits[j].type_view(tv::uint32);
mask_argsort = mask_argsort_fwd_splits[j];
}}else{{
mask = pair_mask_bwd_splits[j].type_view(tv::uint32);
mask_argsort = mask_argsort_bwd_splits[j];
}}
float beta = j == 0 ? 0 : 1;
conv_tuner.run_with_tuned_result(
dgrad_tune_res,
kBackwardInputInt,
din,
filters,
out_bp,
mask,
mask_argsort,
tv::Tensor(), // mask_output
pair_bwd,
is_subm, // reverse_mask
mask_ptr[j],
-1, // mask_width
1.0, beta,
stream_int,
tv::Tensor(), // workspace
false, // verbose
timer);
conv_tuner.run_with_tuned_result(
wgrad_tune_res,
kBackwardWeightInt,
features, dfilters, out_bp,
mask_output_fwd[j].type_view(tv::uint32),
mask_argsort_fwd_splits[j],
tv::Tensor(), // mask_output
pair_fwd,
false, // reverse_mask
mask_ptr[j], // mask_filter
mask_width,
1.0, 0.0,
stream_int,
workspace, // workspace
false, // verbose
timer);
}}
""")
return code
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment