Commit a2dd956c authored by one's avatar one
Browse files

Adapt spconv for DTK SIMT fallback

Add a DTK-specific kernel filter path for running spconv through the
DTK CUDA compatibility layer on BW100. The recommended `dtk_simt`
filter keeps only SIMT kernels, forces SIMT params to static codegen,
and removes Volta/Turing/Ampere TensorOp, int8, and NVRTC paths from
the active kernel set.

Add `dtk_tensorop` as a separate non-default adaptation entry point for
future Ampere TensorOp work. This keeps static non-int8 Ampere TensorOp
params while still excluding Volta/Turing, int8, and NVRTC paths.

Allow fp16 workloads to use SIMT fallback when `SPCONV_DTK_KERNEL_FILTER`
is set to `dtk_simt`. This updates both the Python tuner and generated
C++ ConvTunerSimple logic so fp16 no longer depends on currently
unsupported TensorOp paths on DTK.

Add `SPCONV_FORCE_CUDA_ARCH` to keep runtime dispatch aligned with the
compiled arch list, and keep the BW100 path explicit with `9.3`.

Adjust DTK build/runtime compatibility:
- reuse the guarded editable-install state during constants setup
- skip the Linux Thrust `-fno-gnu-unique` flag under the DTK inline-PTX
  compatibility path
- add launch bounds to helper kernels that are launched with 1024
  threads/block

This leaves full TensorOp, int8, fp8, and NVRTC support out of the
recommended DTK path. Those remain future adaptation work.
parent 263d6b47
Pipeline #3585 failed with stages
in 0 seconds
......@@ -13,6 +13,7 @@
# limitations under the License.
import contextlib
import os
import time
from enum import Enum
from threading import Lock
......@@ -728,7 +729,8 @@ class SimpleConv:
if (desp.tensorop[0] > 0 and inp.dtype == tv.float32
and weight.dtype == tv.float32 and out.dtype == tv.float32):
continue
if arch >= (7, 0) and is_fp16:
allow_dtk_fp16_simt = os.getenv("SPCONV_DTK_KERNEL_FILTER", "").lower() == "dtk_simt"
if arch >= (7, 0) and is_fp16 and not allow_dtk_fp16_simt:
if desp.algo == GemmAlgo.Simt:
continue
if use_f32_as_accum:
......
......@@ -60,7 +60,7 @@ SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,
SPCONV_NVRTC_MODE = NVRTCMode.ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS = False
SPCONV_DEBUG_CPP_ONLY = project_is_editable(PACKAGE_NAME)
SPCONV_DEBUG_CPP_ONLY = EDITABLE_INSTALLED
class AllocKeys:
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from enum import Enum
from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgoParams
from cumm.gemm import kernel
......@@ -1338,6 +1339,65 @@ if not SPCONV_INT8_DEBUG:
int8_inference=True),
])
def _is_ampere_param(param):
    """Tell whether ``param.algo`` compares equal to ``GemmAlgo.Ampere``."""
    algo = param.algo
    return algo == GemmAlgo.Ampere
def _is_simt_param(param):
    """Tell whether ``param.algo`` compares equal to ``GemmAlgo.Simt``."""
    algo = param.algo
    return algo == GemmAlgo.Simt
def _is_static_param(param):
    """Return True unless ``param`` carries a truthy ``is_nvrtc`` attribute.

    A param without an ``is_nvrtc`` attribute is treated as static.
    """
    nvrtc_flag = getattr(param, "is_nvrtc", False)
    return not nvrtc_flag
def _is_non_int8_param(param):
    """Return True unless ``param`` carries a truthy ``int8_inference`` attribute.

    A param without an ``int8_inference`` attribute is treated as non-int8.
    """
    int8_flag = getattr(param, "int8_inference", False)
    return not int8_flag
def _is_ampere_no_int8_static_param(param):
    """Accept only params that are Ampere, static (non-NVRTC) and non-int8.

    The checks run in that order and stop at the first one that fails.
    """
    if not _is_ampere_param(param):
        return False
    if not _is_static_param(param):
        return False
    return _is_non_int8_param(param)
def _filter_params(params, predicate):
    """Return a new list holding the items of ``params`` that ``predicate`` accepts.

    Order is preserved and ``params`` itself is not modified.
    """
    kept = []
    for candidate in params:
        if predicate(candidate):
            kept.append(candidate)
    return kept
def _force_static_params(params, predicate):
    """Set ``is_nvrtc = False`` on every item of ``params`` that ``predicate`` accepts.

    Matching items are modified in place; items the predicate rejects are
    left untouched. The same ``params`` object that was passed in is returned.
    """
    for entry in params:
        if not predicate(entry):
            continue
        entry.is_nvrtc = False
    return params
def _clear_turing_volta():
    """Rebind the module-level Turing and Volta param lists to new empty lists.

    Each of the four globals receives its own list object.
    """
    global SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_VOLTA_PARAMS
    SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS = [], []
    IMPLGEMM_TURING_PARAMS, IMPLGEMM_VOLTA_PARAMS = [], []
# Optional DTK kernel filter, chosen through the SPCONV_DTK_KERNEL_FILTER
# environment variable (compared case-insensitively). An unset or empty
# variable leaves every param list as it is.
_DTK_KERNEL_FILTER = os.getenv("SPCONV_DTK_KERNEL_FILTER", "").lower()
if _DTK_KERNEL_FILTER not in ("", "dtk_simt", "dtk_tensorop"):
    # Reject unrecognised values before any param list is touched.
    raise ValueError(f"unknown SPCONV_DTK_KERNEL_FILTER: {_DTK_KERNEL_FILTER}")
if _DTK_KERNEL_FILTER:
    # Both DTK modes mark SIMT params as non-NVRTC and empty the Turing/Volta
    # lists. They differ only in the Ampere lists: "dtk_tensorop" keeps the
    # static non-int8 Ampere params, "dtk_simt" drops Ampere entirely.
    SHUFFLE_SIMT_PARAMS = _force_static_params(SHUFFLE_SIMT_PARAMS, _is_simt_param)
    SHUFFLE_AMPERE_PARAMS = (
        _filter_params(SHUFFLE_AMPERE_PARAMS, _is_ampere_no_int8_static_param)
        if _DTK_KERNEL_FILTER == "dtk_tensorop"
        else []
    )
    IMPLGEMM_SIMT_PARAMS = _force_static_params(IMPLGEMM_SIMT_PARAMS, _is_simt_param)
    IMPLGEMM_AMPERE_PARAMS = (
        _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_ampere_no_int8_static_param)
        if _DTK_KERNEL_FILTER == "dtk_tensorop"
        else []
    )
    _clear_turing_volta()

# Aggregate lists are built after filtering so they reflect the active set.
ALL_NATIVE_PARAMS = (
    SHUFFLE_SIMT_PARAMS
    + SHUFFLE_TURING_PARAMS
    + SHUFFLE_VOLTA_PARAMS
    + SHUFFLE_AMPERE_PARAMS
)
ALL_IMPGEMM_PARAMS = (
    IMPLGEMM_SIMT_PARAMS
    + IMPLGEMM_TURING_PARAMS
    + IMPLGEMM_VOLTA_PARAMS
    + IMPLGEMM_AMPERE_PARAMS
)
......@@ -36,7 +36,7 @@ class CustomThrustLib(pccm.Class):
super().__init__()
self.add_dependency(ThrustLib)
# https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
if compat.InLinux:
if compat.InLinux and os.getenv("CUMM_DTK_DISABLE_INLINE_PTX", "0") != "1":
self.build_meta.add_public_cflags("nvcc", "-Xcompiler -fno-gnu-unique", "-Xcompiler -fvisibility=hidden")
......
import os
from typing import Optional
import pccm
from cumm.common import GemmBasicHost, NlohmannJson, TensorView
......@@ -1058,7 +1059,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
continue;
}}
}}
if (arch >= std::make_tuple(7, 0) && is_fp16){{
bool allow_dtk_fp16_simt = {pccm.boolean(os.getenv("SPCONV_DTK_KERNEL_FILTER", "").lower() == "dtk_simt")};
if (arch >= std::make_tuple(7, 0) && is_fp16 && !allow_dtk_fp16_simt){{
// skip simt fp16 kernels if we have tensor core
if (desp.algo == {pccm.literal(GemmAlgo.Simt.value)}){{
continue;
......
......@@ -35,6 +35,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def arange_kernel(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T")
code.arg("data", f"T*")
code.arg("size", f"int")
......@@ -48,6 +49,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def fill_kernel(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T")
code.arg("data", f"T*")
code.arg("val", f"T")
......@@ -62,6 +64,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def maximum_value_kernel(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T")
code.arg("data", f"T*")
code.arg("val", f"T")
......@@ -723,6 +726,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def build_subm_conv_hash_table(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("TTable")
code.targ("TLayoutNPQ")
......@@ -806,6 +810,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function
def calc_subm_conv_indices_mask(self):
code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("TTable")
code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from cumm import tensorview as tv
import os
import torch
from typing import Dict, Optional, List, Union
from spconv.constants import AllocKeys
......@@ -100,7 +101,12 @@ def get_current_stream():
def get_arch():
arch = torch.cuda.get_device_capability()
force_arch = os.getenv("SPCONV_FORCE_CUDA_ARCH", "")
if force_arch:
force_arch = force_arch.replace(".", "")
arch = (int(force_arch[:-1]), int(force_arch[-1]))
else:
arch = torch.cuda.get_device_capability()
if not CompileInfo.arch_is_compatible(arch) and not CompileInfo.algo_can_use_ptx((0, 0), arch):
warnings.warn(
f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment