Commit 3b545945 authored by one's avatar one
Browse files

WIP

parent 263d6b47
Pipeline #3583 canceled with stages
...@@ -60,7 +60,7 @@ SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32, ...@@ -60,7 +60,7 @@ SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,
SPCONV_NVRTC_MODE = NVRTCMode.ConstantMemory SPCONV_NVRTC_MODE = NVRTCMode.ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS = False SPCONV_DEBUG_NVRTC_KERNELS = False
SPCONV_DEBUG_CPP_ONLY = project_is_editable(PACKAGE_NAME) SPCONV_DEBUG_CPP_ONLY = EDITABLE_INSTALLED
class AllocKeys: class AllocKeys:
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from enum import Enum from enum import Enum
from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgoParams from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgoParams
from cumm.gemm import kernel from cumm.gemm import kernel
...@@ -1338,6 +1339,152 @@ if not SPCONV_INT8_DEBUG: ...@@ -1338,6 +1339,152 @@ if not SPCONV_INT8_DEBUG:
int8_inference=True), int8_inference=True),
]) ])
def _dtype_shortcuts(param):
return tuple(
getattr(param, name).shortcut()
for name in ("dtype_a", "dtype_b", "dtype_c")
)
def _all_dtypes_are(param, dtype_shortcut: str):
return all(shortcut == dtype_shortcut for shortcut in _dtype_shortcuts(param))
def _has_any_dtype(param, dtype_shortcuts):
return any(shortcut in dtype_shortcuts for shortcut in _dtype_shortcuts(param))
def _is_fp32_simt_param(param):
return param.algo == GemmAlgo.Simt and _all_dtypes_are(param, "f32")
def _is_ampere_param(param):
return param.algo == GemmAlgo.Ampere
def _is_static_param(param):
return not getattr(param, "is_nvrtc", False)
def _is_non_int8_param(param):
return not getattr(param, "int8_inference", False)
def _is_fp32_ampere_param(param):
return (
_is_ampere_param(param)
and _is_static_param(param)
and _is_non_int8_param(param)
and _all_dtypes_are(param, "f32")
)
def _is_f16_ampere_param(param):
return (
_is_ampere_param(param)
and _is_static_param(param)
and _is_non_int8_param(param)
and _has_any_dtype(param, {"f16"})
)
def _is_ampere_no_int8_static_param(param):
return _is_ampere_param(param) and _is_static_param(param) and _is_non_int8_param(param)
def _is_ampere_int8_param(param):
return _is_ampere_param(param) and getattr(param, "int8_inference", False)
def _is_non_int8_nvrtc_param(param):
return getattr(param, "is_nvrtc", False) and _is_non_int8_param(param)
def _is_fp8_param(param):
return _has_any_dtype(param, {"e4m3", "e5m2"})
def _filter_params(params, predicate):
return [param for param in params if predicate(param)]
def _clear_turing_volta():
global SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS
global IMPLGEMM_TURING_PARAMS, IMPLGEMM_VOLTA_PARAMS
SHUFFLE_TURING_PARAMS = []
SHUFFLE_VOLTA_PARAMS = []
IMPLGEMM_TURING_PARAMS = []
IMPLGEMM_VOLTA_PARAMS = []
_DTK_KERNEL_FILTER = os.getenv("SPCONV_DTK_KERNEL_FILTER", "").lower()
if _DTK_KERNEL_FILTER == "dtk_smoke":
SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_fp32_simt_param)[:4]
SHUFFLE_AMPERE_PARAMS = []
IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_fp32_simt_param)[:4]
IMPLGEMM_AMPERE_PARAMS = []
_clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_fp32_simt":
SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_fp32_simt_param)
SHUFFLE_AMPERE_PARAMS = []
IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_fp32_simt_param)
IMPLGEMM_AMPERE_PARAMS = []
_clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_all_simt":
SHUFFLE_AMPERE_PARAMS = []
IMPLGEMM_AMPERE_PARAMS = []
_clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_fp32_ampere":
SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_fp32_simt_param)
SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_fp32_ampere_param)
IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_fp32_simt_param)
IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_fp32_ampere_param)
_clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_f16_ampere":
SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_fp32_simt_param)
SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_f16_ampere_param)
IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_fp32_simt_param)
IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_f16_ampere_param)
_clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_ampere_no_int8":
SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_ampere_no_int8_static_param)
IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_ampere_no_int8_static_param)
_clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_int8_ampere":
SHUFFLE_AMPERE_PARAMS = _filter_params(
SHUFFLE_AMPERE_PARAMS,
lambda param: _is_ampere_no_int8_static_param(param) or _is_ampere_int8_param(param))
IMPLGEMM_AMPERE_PARAMS = _filter_params(
IMPLGEMM_AMPERE_PARAMS,
lambda param: _is_ampere_no_int8_static_param(param) or _is_ampere_int8_param(param))
_clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_nvrtc":
SHUFFLE_AMPERE_PARAMS = _filter_params(
SHUFFLE_AMPERE_PARAMS,
lambda param: _is_ampere_no_int8_static_param(param) or _is_non_int8_nvrtc_param(param))
IMPLGEMM_AMPERE_PARAMS = _filter_params(
IMPLGEMM_AMPERE_PARAMS,
lambda param: _is_ampere_no_int8_static_param(param) or _is_non_int8_nvrtc_param(param))
_clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_fp8_probe":
SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_fp32_simt_param)
SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_fp8_param)
IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_fp32_simt_param)
IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_fp8_param)
_clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_all_static_no_nvrtc":
SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_static_param)
SHUFFLE_TURING_PARAMS = _filter_params(SHUFFLE_TURING_PARAMS, _is_static_param)
SHUFFLE_VOLTA_PARAMS = _filter_params(SHUFFLE_VOLTA_PARAMS, _is_static_param)
SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_static_param)
IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_static_param)
IMPLGEMM_TURING_PARAMS = _filter_params(IMPLGEMM_TURING_PARAMS, _is_static_param)
IMPLGEMM_VOLTA_PARAMS = _filter_params(IMPLGEMM_VOLTA_PARAMS, _is_static_param)
IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_static_param)
elif _DTK_KERNEL_FILTER == "dtk_all":
pass
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_AMPERE_PARAMS ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_AMPERE_PARAMS
ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_AMPERE_PARAMS ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_AMPERE_PARAMS
...@@ -36,7 +36,7 @@ class CustomThrustLib(pccm.Class): ...@@ -36,7 +36,7 @@ class CustomThrustLib(pccm.Class):
super().__init__() super().__init__()
self.add_dependency(ThrustLib) self.add_dependency(ThrustLib)
# https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746 # https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
if compat.InLinux: if compat.InLinux and os.getenv("CUMM_DTK_DISABLE_INLINE_PTX", "0") != "1":
self.build_meta.add_public_cflags("nvcc", "-Xcompiler -fno-gnu-unique", "-Xcompiler -fvisibility=hidden") self.build_meta.add_public_cflags("nvcc", "-Xcompiler -fno-gnu-unique", "-Xcompiler -fvisibility=hidden")
......
...@@ -35,6 +35,7 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -35,6 +35,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def arange_kernel(self): def arange_kernel(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T") code.targ("T")
code.arg("data", f"T*") code.arg("data", f"T*")
code.arg("size", f"int") code.arg("size", f"int")
...@@ -48,6 +49,7 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -48,6 +49,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def fill_kernel(self): def fill_kernel(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T") code.targ("T")
code.arg("data", f"T*") code.arg("data", f"T*")
code.arg("val", f"T") code.arg("val", f"T")
...@@ -62,6 +64,7 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -62,6 +64,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def maximum_value_kernel(self): def maximum_value_kernel(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T") code.targ("T")
code.arg("data", f"T*") code.arg("data", f"T*")
code.arg("val", f"T") code.arg("val", f"T")
...@@ -723,6 +726,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -723,6 +726,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def build_subm_conv_hash_table(self): def build_subm_conv_hash_table(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("TTable") code.targ("TTable")
code.targ("TLayoutNPQ") code.targ("TLayoutNPQ")
...@@ -806,6 +810,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -806,6 +810,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_subm_conv_indices_mask(self): def calc_subm_conv_indices_mask(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("TTable") code.targ("TTable")
code.targ("TConvLocIter") code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1] code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from cumm import tensorview as tv from cumm import tensorview as tv
import os
import torch import torch
from typing import Dict, Optional, List, Union from typing import Dict, Optional, List, Union
from spconv.constants import AllocKeys from spconv.constants import AllocKeys
...@@ -100,7 +101,12 @@ def get_current_stream(): ...@@ -100,7 +101,12 @@ def get_current_stream():
def get_arch(): def get_arch():
arch = torch.cuda.get_device_capability() force_arch = os.getenv("SPCONV_FORCE_CUDA_ARCH", "")
if force_arch:
force_arch = force_arch.replace(".", "")
arch = (int(force_arch[:-1]), int(force_arch[-1]))
else:
arch = torch.cuda.get_device_capability()
if not CompileInfo.arch_is_compatible(arch) and not CompileInfo.algo_can_use_ptx((0, 0), arch): if not CompileInfo.arch_is_compatible(arch) and not CompileInfo.algo_can_use_ptx((0, 0), arch):
warnings.warn( warnings.warn(
f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, " f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment