Commit a2dd956c authored by one's avatar one
Browse files

Adapt spconv for DTK SIMT fallback

Add a DTK-specific kernel filter path for running spconv through the
DTK CUDA compatibility layer on BW100. The recommended `dtk_simt`
filter keeps only SIMT kernels, forces SIMT params to static codegen,
and removes Volta/Turing/Ampere TensorOp, int8, and NVRTC paths from
the active kernel set.

Add `dtk_tensorop` as a separate non-default adaptation entry point for
future Ampere TensorOp work. This keeps static non-int8 Ampere TensorOp
params while still excluding Volta/Turing, int8, and NVRTC paths.

Allow fp16 workloads to use SIMT fallback when `SPCONV_DTK_KERNEL_FILTER`
is set to `dtk_simt`. This updates both the Python tuner and generated
C++ ConvTunerSimple logic so fp16 no longer depends on currently
unsupported TensorOp paths on DTK.

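Add the `SPCONV_FORCE_CUDA_ARCH` environment variable, which overrides the
device capability reported by `torch.cuda.get_device_capability()` in
`get_arch()`, to keep runtime dispatch aligned with the compiled arch list;
on BW100 set it explicitly to `9.3`.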

Adjust DTK build/runtime compatibility:
- reuse the guarded editable-install state during constants setup
- skip the Linux Thrust `-fno-gnu-unique` and `-fvisibility=hidden` flags when
  `CUMM_DTK_DISABLE_INLINE_PTX=1` (the DTK inline-PTX compatibility path)
- add launch bounds to helper kernels that are launched with 1024
  threads/block

This leaves full TensorOp, int8, fp8, and NVRTC support out of the
recommended DTK path. Those remain future adaptation work.
parent 263d6b47
Pipeline #3585 failed with stages
in 0 seconds
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import os
import time import time
from enum import Enum from enum import Enum
from threading import Lock from threading import Lock
...@@ -728,7 +729,8 @@ class SimpleConv: ...@@ -728,7 +729,8 @@ class SimpleConv:
if (desp.tensorop[0] > 0 and inp.dtype == tv.float32 if (desp.tensorop[0] > 0 and inp.dtype == tv.float32
and weight.dtype == tv.float32 and out.dtype == tv.float32): and weight.dtype == tv.float32 and out.dtype == tv.float32):
continue continue
if arch >= (7, 0) and is_fp16: allow_dtk_fp16_simt = os.getenv("SPCONV_DTK_KERNEL_FILTER", "").lower() == "dtk_simt"
if arch >= (7, 0) and is_fp16 and not allow_dtk_fp16_simt:
if desp.algo == GemmAlgo.Simt: if desp.algo == GemmAlgo.Simt:
continue continue
if use_f32_as_accum: if use_f32_as_accum:
......
...@@ -60,7 +60,7 @@ SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32, ...@@ -60,7 +60,7 @@ SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,
SPCONV_NVRTC_MODE = NVRTCMode.ConstantMemory SPCONV_NVRTC_MODE = NVRTCMode.ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS = False SPCONV_DEBUG_NVRTC_KERNELS = False
SPCONV_DEBUG_CPP_ONLY = project_is_editable(PACKAGE_NAME) SPCONV_DEBUG_CPP_ONLY = EDITABLE_INSTALLED
class AllocKeys: class AllocKeys:
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from enum import Enum from enum import Enum
from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgoParams from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgoParams
from cumm.gemm import kernel from cumm.gemm import kernel
...@@ -1338,6 +1339,65 @@ if not SPCONV_INT8_DEBUG: ...@@ -1338,6 +1339,65 @@ if not SPCONV_INT8_DEBUG:
int8_inference=True), int8_inference=True),
]) ])
def _is_ampere_param(param):
    """Return whether ``param.algo`` equals ``GemmAlgo.Ampere``."""
    algo = param.algo
    return algo == GemmAlgo.Ampere
def _is_simt_param(param):
    """Return whether ``param.algo`` equals ``GemmAlgo.Simt``."""
    algo = param.algo
    return algo == GemmAlgo.Simt
def _is_static_param(param):
    """Return True unless ``param`` carries a truthy ``is_nvrtc`` attribute.

    A parameter object without an ``is_nvrtc`` attribute is treated as
    static (the attribute defaults to ``False``).
    """
    nvrtc_flag = getattr(param, "is_nvrtc", False)
    return not nvrtc_flag
def _is_non_int8_param(param):
    """Return True unless ``param`` carries a truthy ``int8_inference`` attribute.

    A parameter object without an ``int8_inference`` attribute is treated
    as non-int8 (the attribute defaults to ``False``).
    """
    int8_flag = getattr(param, "int8_inference", False)
    return not int8_flag
def _is_ampere_no_int8_static_param(param):
    """Return whether ``param`` passes all three checks, in this order.

    The checks are ``_is_ampere_param``, ``_is_static_param`` and
    ``_is_non_int8_param``; evaluation stops at the first one that fails.
    """
    if not _is_ampere_param(param):
        return False
    if not _is_static_param(param):
        return False
    return _is_non_int8_param(param)
def _filter_params(params, predicate):
    """Return a new list holding the entries of ``params`` that ``predicate`` accepts.

    Order is preserved and ``params`` itself is left unmodified.
    """
    kept = []
    for candidate in params:
        if predicate(candidate):
            kept.append(candidate)
    return kept
def _force_static_params(params, predicate):
    """Set ``is_nvrtc = False`` on every entry of ``params`` accepted by ``predicate``.

    The entries are mutated in place and the same ``params`` object is
    returned, so callers can rebind the result to the original name.
    """
    for entry in params:
        if not predicate(entry):
            continue
        setattr(entry, "is_nvrtc", False)
    return params
def _clear_turing_volta():
    """Rebind the four module-level Turing/Volta parameter lists to new empty lists.

    Each global receives its own distinct list object.
    """
    global SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS, IMPLGEMM_VOLTA_PARAMS
    SHUFFLE_TURING_PARAMS, SHUFFLE_VOLTA_PARAMS = [], []
    IMPLGEMM_TURING_PARAMS, IMPLGEMM_VOLTA_PARAMS = [], []
# Kernel-set selection driven by the SPCONV_DTK_KERNEL_FILTER environment
# variable (compared case-insensitively):
#   ""             -> every parameter list is left untouched
#   "dtk_simt"     -> SIMT params forced to is_nvrtc=False; the Ampere,
#                     Turing and Volta lists are emptied
#   "dtk_tensorop" -> same as "dtk_simt", except Ampere params that are
#                     static and non-int8 are kept
# Any other value raises ValueError before any list is modified.
_DTK_KERNEL_FILTER = os.getenv("SPCONV_DTK_KERNEL_FILTER", "").lower()
if _DTK_KERNEL_FILTER not in ("", "dtk_simt", "dtk_tensorop"):
    raise ValueError(f"unknown SPCONV_DTK_KERNEL_FILTER: {_DTK_KERNEL_FILTER}")
if _DTK_KERNEL_FILTER:
    # Steps shared by both DTK modes.
    SHUFFLE_SIMT_PARAMS = _force_static_params(SHUFFLE_SIMT_PARAMS, _is_simt_param)
    IMPLGEMM_SIMT_PARAMS = _force_static_params(IMPLGEMM_SIMT_PARAMS, _is_simt_param)
    if _DTK_KERNEL_FILTER == "dtk_tensorop":
        SHUFFLE_AMPERE_PARAMS = _filter_params(SHUFFLE_AMPERE_PARAMS, _is_ampere_no_int8_static_param)
        IMPLGEMM_AMPERE_PARAMS = _filter_params(IMPLGEMM_AMPERE_PARAMS, _is_ampere_no_int8_static_param)
    else:
        SHUFFLE_AMPERE_PARAMS = []
        IMPLGEMM_AMPERE_PARAMS = []
    _clear_turing_volta()
ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_AMPERE_PARAMS ALL_NATIVE_PARAMS = SHUFFLE_SIMT_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_AMPERE_PARAMS
ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_AMPERE_PARAMS ALL_IMPGEMM_PARAMS = IMPLGEMM_SIMT_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_AMPERE_PARAMS
...@@ -36,7 +36,7 @@ class CustomThrustLib(pccm.Class): ...@@ -36,7 +36,7 @@ class CustomThrustLib(pccm.Class):
super().__init__() super().__init__()
self.add_dependency(ThrustLib) self.add_dependency(ThrustLib)
# https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746 # https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
if compat.InLinux: if compat.InLinux and os.getenv("CUMM_DTK_DISABLE_INLINE_PTX", "0") != "1":
self.build_meta.add_public_cflags("nvcc", "-Xcompiler -fno-gnu-unique", "-Xcompiler -fvisibility=hidden") self.build_meta.add_public_cflags("nvcc", "-Xcompiler -fno-gnu-unique", "-Xcompiler -fvisibility=hidden")
......
import os
from typing import Optional from typing import Optional
import pccm import pccm
from cumm.common import GemmBasicHost, NlohmannJson, TensorView from cumm.common import GemmBasicHost, NlohmannJson, TensorView
...@@ -1058,7 +1059,8 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1058,7 +1059,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
continue; continue;
}} }}
}} }}
if (arch >= std::make_tuple(7, 0) && is_fp16){{ bool allow_dtk_fp16_simt = {pccm.boolean(os.getenv("SPCONV_DTK_KERNEL_FILTER", "").lower() == "dtk_simt")};
if (arch >= std::make_tuple(7, 0) && is_fp16 && !allow_dtk_fp16_simt){{
// skip simt fp16 kernels if we have tensor core // skip simt fp16 kernels if we have tensor core
if (desp.algo == {pccm.literal(GemmAlgo.Simt.value)}){{ if (desp.algo == {pccm.literal(GemmAlgo.Simt.value)}){{
continue; continue;
......
...@@ -35,6 +35,7 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -35,6 +35,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def arange_kernel(self): def arange_kernel(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T") code.targ("T")
code.arg("data", f"T*") code.arg("data", f"T*")
code.arg("size", f"int") code.arg("size", f"int")
...@@ -48,6 +49,7 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -48,6 +49,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def fill_kernel(self): def fill_kernel(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T") code.targ("T")
code.arg("data", f"T*") code.arg("data", f"T*")
code.arg("val", f"T") code.arg("val", f"T")
...@@ -62,6 +64,7 @@ class CudaCommonKernel(pccm.ParameterizedClass): ...@@ -62,6 +64,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def maximum_value_kernel(self): def maximum_value_kernel(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("T") code.targ("T")
code.arg("data", f"T*") code.arg("data", f"T*")
code.arg("val", f"T") code.arg("val", f"T")
...@@ -723,6 +726,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -723,6 +726,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def build_subm_conv_hash_table(self): def build_subm_conv_hash_table(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("TTable") code.targ("TTable")
code.targ("TLayoutNPQ") code.targ("TLayoutNPQ")
...@@ -806,6 +810,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass): ...@@ -806,6 +810,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def calc_subm_conv_indices_mask(self): def calc_subm_conv_indices_mask(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.add_pre_attr("__launch_bounds__(1024)")
code.targ("TTable") code.targ("TTable")
code.targ("TConvLocIter") code.targ("TConvLocIter")
code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1] code.arg("loc_iter", f"TConvLocIter") # [N, ndim + 1]
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from cumm import tensorview as tv from cumm import tensorview as tv
import os
import torch import torch
from typing import Dict, Optional, List, Union from typing import Dict, Optional, List, Union
from spconv.constants import AllocKeys from spconv.constants import AllocKeys
...@@ -100,6 +101,11 @@ def get_current_stream(): ...@@ -100,6 +101,11 @@ def get_current_stream():
def get_arch(): def get_arch():
force_arch = os.getenv("SPCONV_FORCE_CUDA_ARCH", "")
if force_arch:
force_arch = force_arch.replace(".", "")
arch = (int(force_arch[:-1]), int(force_arch[-1]))
else:
arch = torch.cuda.get_device_capability() arch = torch.cuda.get_device_capability()
if not CompileInfo.arch_is_compatible(arch) and not CompileInfo.algo_can_use_ptx((0, 0), arch): if not CompileInfo.arch_is_compatible(arch) and not CompileInfo.algo_can_use_ptx((0, 0), arch):
warnings.warn( warnings.warn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment