Commit 3b545945 authored by one's avatar one
Browse files

WIP

parent 263d6b47
Pipeline #3583 canceled with stages
......@@ -60,7 +60,7 @@ SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,
SPCONV_NVRTC_MODE = NVRTCMode.ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS = False
SPCONV_DEBUG_CPP_ONLY = project_is_editable(PACKAGE_NAME)
SPCONV_DEBUG_CPP_ONLY = EDITABLE_INSTALLED
class AllocKeys:
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from enum import Enum
from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgoParams
from cumm.gemm import kernel
......@@ -1338,6 +1339,152 @@ if not SPCONV_INT8_DEBUG:
int8_inference=True),
])
def _dtype_shortcuts(param):
    """Return the shortcut strings of the A, B and C dtypes of ``param``.

    The result is a 3-tuple ordered ``(dtype_a, dtype_b, dtype_c)``; each
    entry is whatever the corresponding dtype object's ``shortcut()`` returns.
    """
    return (
        param.dtype_a.shortcut(),
        param.dtype_b.shortcut(),
        param.dtype_c.shortcut(),
    )
def _all_dtypes_are(param, dtype_shortcut: str):
    """Return True when every dtype shortcut of ``param`` equals ``dtype_shortcut``."""
    for code in _dtype_shortcuts(param):
        if not (code == dtype_shortcut):
            return False
    return True
def _has_any_dtype(param, dtype_shortcuts):
    """Return True when at least one dtype shortcut of ``param`` is in ``dtype_shortcuts``."""
    for code in _dtype_shortcuts(param):
        if code in dtype_shortcuts:
            return True
    return False
def _is_fp32_simt_param(param):
    """Select params whose algo is ``GemmAlgo.Simt`` and whose dtypes are all ``f32``."""
    uses_simt = param.algo == GemmAlgo.Simt
    # The dtype check only runs when the algo comparison succeeds.
    return uses_simt and _all_dtypes_are(param, "f32")
def _is_ampere_param(param):
    """Return whether ``param.algo`` compares equal to ``GemmAlgo.Ampere``."""
    uses_ampere = param.algo == GemmAlgo.Ampere
    return uses_ampere
def _is_static_param(param):
    """Return True unless ``param`` carries a truthy ``is_nvrtc`` attribute.

    A missing ``is_nvrtc`` attribute is treated as False, so such params
    count as static.
    """
    nvrtc_flag = getattr(param, "is_nvrtc", False)
    return not nvrtc_flag
def _is_non_int8_param(param):
    """Return True unless ``param`` carries a truthy ``int8_inference`` attribute.

    A missing ``int8_inference`` attribute is treated as False.
    """
    int8_flag = getattr(param, "int8_inference", False)
    return not int8_flag
def _is_fp32_ampere_param(param):
    """Select Ampere params that are static, non-int8 and all-``f32``.

    The first three conditions (Ampere algo, not NVRTC, not int8) are the
    ones checked by ``_is_ampere_no_int8_static_param``; the dtype check is
    evaluated last and only when they hold.
    """
    return _is_ampere_no_int8_static_param(param) and _all_dtypes_are(param, "f32")
def _is_f16_ampere_param(param):
    """Select Ampere params that are static, non-int8 and use ``f16`` somewhere.

    "Somewhere" means at least one of the A/B/C dtype shortcuts is ``"f16"``;
    that check is evaluated last and only when the other conditions hold.
    """
    f16_only = {"f16"}
    return _is_ampere_no_int8_static_param(param) and _has_any_dtype(param, f16_only)
def _is_ampere_no_int8_static_param(param):
    """Select Ampere params that are neither NVRTC nor int8-inference params.

    Checks run in order (algo, static, non-int8) and stop at the first
    failing one.
    """
    on_ampere = _is_ampere_param(param)
    return (
        on_ampere
        and _is_static_param(param)
        and _is_non_int8_param(param)
    )
def _is_ampere_int8_param(param):
    """Select Ampere params flagged with a truthy ``int8_inference`` attribute.

    When the algo check passes, the value returned is the raw
    ``int8_inference`` attribute (or False if it is missing), so callers
    should rely on truthiness only.
    """
    on_ampere = _is_ampere_param(param)
    return on_ampere and getattr(param, "int8_inference", False)
def _is_non_int8_nvrtc_param(param):
    """Select params with a truthy ``is_nvrtc`` that are not int8-inference params.

    A missing ``is_nvrtc`` attribute counts as False; in that case (or for
    any falsy flag) the flag value itself is returned.
    """
    nvrtc_flag = getattr(param, "is_nvrtc", False)
    return nvrtc_flag and _is_non_int8_param(param)
def _is_fp8_param(param):
    """Return True when any dtype shortcut of ``param`` is ``e4m3`` or ``e5m2``."""
    fp8_shortcuts = {"e4m3", "e5m2"}
    return _has_any_dtype(param, fp8_shortcuts)
def _filter_params(params, predicate):
    """Return a new list of the items in ``params`` for which ``predicate`` is truthy.

    Input order is preserved and ``params`` itself is not modified.
    """
    return list(filter(predicate, params))
def _clear_turing_volta():
    """Rebind the four Turing/Volta parameter lists to fresh empty lists.

    Affects the module-level names ``SHUFFLE_TURING_PARAMS``,
    ``SHUFFLE_VOLTA_PARAMS``, ``IMPLGEMM_TURING_PARAMS`` and
    ``IMPLGEMM_VOLTA_PARAMS``. Each name receives its own list object.
    """
    global SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS
    global IMPLGEMM_TURING_PARAMS, IMPLGEMM_VOLTA_PARAMS
    SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS = [], []
    IMPLGEMM_TURING_PARAMS, IMPLGEMM_VOLTA_PARAMS = [], []
# Optional kernel-set filter, selected through the SPCONV_DTK_KERNEL_FILTER
# environment variable (compared case-insensitively; unset means "").
# Each recognized value rebinds some of the SHUFFLE_*/IMPLGEMM_* parameter
# lists below. There is no final ``else``: an empty or unrecognized value
# leaves every list untouched.
_DTK_KERNEL_FILTER = os.getenv("SPCONV_DTK_KERNEL_FILTER", "").lower()
if _DTK_KERNEL_FILTER == "dtk_smoke":
    # Smallest set: only the first four all-f32 SIMT params of each kind,
    # no Ampere/Turing/Volta params.
    SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_fp32_simt_param)[:4]
    SHUFFLE_AMPERE_PARAMS = []
    IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_fp32_simt_param)[:4]
    IMPLGEMM_AMPERE_PARAMS = []
    _clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_fp32_simt":
    # Every all-f32 SIMT param; no Ampere/Turing/Volta params.
    SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_fp32_simt_param)
    SHUFFLE_AMPERE_PARAMS = []
    IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_fp32_simt_param)
    IMPLGEMM_AMPERE_PARAMS = []
    _clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_all_simt":
    # SIMT lists are left unfiltered; everything else is dropped.
    SHUFFLE_AMPERE_PARAMS = []
    IMPLGEMM_AMPERE_PARAMS = []
    _clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_fp32_ampere":
    # All-f32 SIMT params plus static, non-int8, all-f32 Ampere params.
    SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_fp32_simt_param)
    SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_fp32_ampere_param)
    IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_fp32_simt_param)
    IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_fp32_ampere_param)
    _clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_f16_ampere":
    # All-f32 SIMT params plus static, non-int8 Ampere params that use f16
    # for at least one of their dtypes.
    SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_fp32_simt_param)
    SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_f16_ampere_param)
    IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_fp32_simt_param)
    IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_f16_ampere_param)
    _clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_ampere_no_int8":
    # SIMT lists unfiltered; Ampere lists limited to static, non-int8 params.
    SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_ampere_no_int8_static_param)
    IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_ampere_no_int8_static_param)
    _clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_int8_ampere":
    # Same as "dtk_ampere_no_int8", but Ampere int8-inference params are
    # kept as well.
    SHUFFLE_AMPERE_PARAMS = _filter_params(
        SHUFFLE_AMPERE_PARAMS,
        lambda param: _is_ampere_no_int8_static_param(param) or _is_ampere_int8_param(param))
    IMPLGEMM_AMPERE_PARAMS = _filter_params(
        IMPLGEMM_AMPERE_PARAMS,
        lambda param: _is_ampere_no_int8_static_param(param) or _is_ampere_int8_param(param))
    _clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_nvrtc":
    # Same as "dtk_ampere_no_int8", but non-int8 NVRTC params found in the
    # Ampere lists are kept as well.
    SHUFFLE_AMPERE_PARAMS = _filter_params(
        SHUFFLE_AMPERE_PARAMS,
        lambda param: _is_ampere_no_int8_static_param(param) or _is_non_int8_nvrtc_param(param))
    IMPLGEMM_AMPERE_PARAMS = _filter_params(
        IMPLGEMM_AMPERE_PARAMS,
        lambda param: _is_ampere_no_int8_static_param(param) or _is_non_int8_nvrtc_param(param))
    _clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_fp8_probe":
    # All-f32 SIMT params plus any entry of the Ampere lists with an
    # e4m3/e5m2 dtype. NOTE(review): unlike the other Ampere filters, this
    # one applies no static/int8 checks - confirm that is intended.
    SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_fp32_simt_param)
    SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_fp8_param)
    IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_fp32_simt_param)
    IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_fp8_param)
    _clear_turing_volta()
elif _DTK_KERNEL_FILTER == "dtk_all_static_no_nvrtc":
    # Keep every architecture (Turing/Volta are NOT cleared here) but drop
    # params flagged ``is_nvrtc`` from all eight lists.
    SHUFFLE_SIMT_PARAMS = _filter_params(SHUFFLE_SIMT_PARAMS, _is_static_param)
    SHUFFLE_TURING_PARAMS = _filter_params(SHUFFLE_TURING_PARAMS, _is_static_param)
    SHUFFLE_VOLTA_PARAMS = _filter_params(SHUFFLE_VOLTA_PARAMS, _is_static_param)
    SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_static_param)
    IMPLGEMM_SIMT_PARAMS = _filter_params(IMPLGEMM_SIMT_PARAMS, _is_static_param)
    IMPLGEMM_TURING_PARAMS = _filter_params(IMPLGEMM_TURING_PARAMS, _is_static_param)
    IMPLGEMM_VOLTA_PARAMS = _filter_params(IMPLGEMM_VOLTA_PARAMS, _is_static_param)
    IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_static_param)
elif _DTK_KERNEL_FILTER == "dtk_all":
    # Explicit no-op: keep every list exactly as defined above.
    pass
# The aggregate lists are built after filtering, so they reflect whichever
# filter (if any) was applied above.
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_AMPERE_PARAMS
ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_AMPERE_PARAMS
......@@ -36,7 +36,7 @@ class CustomThrustLib(pccm.Class):
super().__init__()
self.add_dependency(ThrustLib)
# https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
if compat.InLinux:
if compat.InLinux and os.getenv("CUMM_DTK_DISABLE_INLINE_PTX", "0") != "1":
self.build_meta.add_public_cflags("nvcc", "-Xcompiler -fno-gnu-unique", "-Xcompiler -fvisibility=hidden")
......
......@@ -35,6 +35,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def arange_kernel(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T")
code.arg("data", f"T*")
code.arg("size", f"int")
......@@ -48,6 +49,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def fill_kernel(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T")
code.arg("data", f"T*")
code.arg("val", f"T")
......@@ -62,6 +64,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def maximum_value_kernel(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T")
code.arg("data", f"T*")
code.arg("val", f"T")
......@@ -723,6 +726,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def build_subm_conv_hash_table(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("TTable")
code.targ("TLayoutNPQ")
......@@ -806,6 +810,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_subm_conv_indices_mask(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("TTable")
code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from cumm import tensorview as tv
import os
import torch
from typing import Dict, Optional, List, Union
from spconv.constants import AllocKeys
......@@ -100,7 +101,12 @@ def get_current_stream():
def get_arch():
arch = torch.cuda.get_device_capability()
force_arch = os.getenv("SPCONV_FORCE_CUDA_ARCH", "")
if force_arch:
force_arch = force_arch.replace(".", "")
arch = (int(force_arch[:-1]), int(force_arch[-1]))
else:
arch = torch.cuda.get_device_capability()
if not CompileInfo.arch_is_compatible(arch) and not CompileInfo.algo_can_use_ptx((0, 0), arch):
warnings.warn(
f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, "
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment