Commit bf34f040 authored by yan.yan's avatar yan.yan
Browse files

fix build and nvrtc problem

parent 8c25ed52
# Changelog
## [2.2.1] - 2022-9-25
### Fixed
- Fix build problem
- Fix nvrtc problem
## [2.2.0] - 2022-9-24
### Added
- Add Ampere support: faster fp16, faster tf32, and much faster int8 kernels on Ampere GPUs.
......
[build-system]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.1"]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.2"]
build-backend = "setuptools.build_meta"
......@@ -38,9 +38,9 @@ if cuda_ver:
cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102
RELEASE_NAME += "-cu{}".format(cuda_ver)
deps = ["cumm-cu{}>=0.3.1".format(cuda_ver)]
deps = ["cumm-cu{}>=0.3.2".format(cuda_ver)]
else:
deps = ["cumm>=0.3.1"]
deps = ["cumm>=0.3.2"]
......
......@@ -17,7 +17,7 @@ import time
from enum import Enum
from threading import Lock
from typing import Dict, List, Optional, Set, Tuple, Union
from spconv.core_cc.cumm.common import CompileInfo
import numpy as np
from cumm import tensorview as tv
from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType
......@@ -337,9 +337,20 @@ class SimpleGemm:
ldb = b.stride[0]
ldc = c.stride[0]
if desp.supported_ldx(lda, ldb, ldc):
if arch not in COMPILED_CUDA_GEMM_ARCHS:
desp = desp.copy()
desp.is_nvrtc = True
if desp.is_nvrtc:
if not CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
continue
if not CompileInfo.arch_is_compiled_gemm(arch):
# use PTX if possible
if not CompileInfo.gemm_algo_can_use_ptx(desp.min_arch, arch):
if CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
# The compiled kernel can't use PTX: for example, desp needs at least sm_80 but only sm_75+PTX is compiled,
# so all sm_80 code of this desp is invalid and we must use nvrtc.
# Only a desp that needs <= sm_75 can use virtual PTX code to generate sm_80 code.
desp = desp.copy()
desp.is_nvrtc = True
else:
continue
if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True
finally_algos.append(desp)
......@@ -455,7 +466,7 @@ class SimpleGemm:
if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
split_k_slices = max(min(32, k // 128), 1)
params = GemmParams()
if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names:
if desp.is_nvrtc or str(desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(desp, arch)
params.a = a
params.b = b
......@@ -550,7 +561,7 @@ class SimpleGemm:
split_k_slices = profile_res.splitk
params = GemmParams()
is_not_static = str(algo_desp) not in self.prebuilt_desp_names
if algo_desp.is_nvrtc and (is_not_static or force_nvrtc):
if algo_desp.is_nvrtc or is_not_static or force_nvrtc:
params.nvrtc_params = self._cached_get_nvrtc_params(
algo_desp, profile_res.arch)
......@@ -720,9 +731,20 @@ class SimpleConv:
assert mask_width > 0
mask_width_valid = mask_width % desp.tile_shape[2] == 0
if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
if arch not in COMPILED_CUDA_GEMM_ARCHS:
desp = desp.copy()
desp.is_nvrtc = True
if desp.is_nvrtc:
if not CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
continue
if not CompileInfo.arch_is_compiled_gemm(arch):
# use PTX if possible
if not CompileInfo.gemm_algo_can_use_ptx(desp.min_arch, arch):
if CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
# The compiled kernel can't use PTX: for example, desp needs at least sm_80 but only sm_75+PTX is compiled,
# so all sm_80 code of this desp is invalid and we must use nvrtc.
# Only a desp that needs <= sm_75 can use virtual PTX code to generate sm_80 code.
desp = desp.copy()
desp.is_nvrtc = True
else:
continue
if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True
finally_algos.append(desp)
......@@ -826,7 +848,7 @@ class SimpleConv:
for desp in avail:
# for sparse conv, ndim isn't used, so we just provide a constant value.
params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type.value))
if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names:
if desp.is_nvrtc or str(desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(desp, arch)
params.conv_algo_desp = desp
......@@ -935,7 +957,7 @@ class SimpleConv:
params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type_value))
is_not_static = str(
algo_desp) not in self.prebuilt_desp_names
if force_nvrtc or (algo_desp.is_nvrtc and is_not_static):
if force_nvrtc or algo_desp.is_nvrtc or is_not_static:
params.nvrtc_params = self._cached_get_nvrtc_params(
algo_desp, profile_res.arch)
params.conv_algo_desp = profile_res.algo_desp
......
......@@ -3,7 +3,6 @@ from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview import Tensor
from ...csrc.sparse.convops import ExternalSpconvMatmul
class GemmTuneResult:
algo_desp: GemmAlgoDesp
arch: Tuple[int, int]
......
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
class CompileInfo:
@staticmethod
def get_compiled_cuda_version() -> Tuple[int, int]: ...
@staticmethod
def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ...
@staticmethod
......@@ -19,3 +21,40 @@ class CompileInfo:
arch:
"""
...
@staticmethod
def arch_is_compatible(arch: Tuple[int, int]) -> bool:
"""
Args:
arch:
"""
...
@staticmethod
def arch_is_compatible_gemm(arch: Tuple[int, int]) -> bool:
"""
Args:
arch:
"""
...
@staticmethod
def algo_can_use_ptx(min_arch: Tuple[int, int], arch: Tuple[int, int]) -> bool:
"""
Args:
min_arch:
arch:
"""
...
@staticmethod
def gemm_algo_can_use_ptx(min_arch: Tuple[int, int], arch: Tuple[int, int]) -> bool:
"""
Args:
min_arch:
arch:
"""
...
@staticmethod
def algo_can_be_nvrtc_compiled(min_arch: Tuple[int, int]) -> bool:
"""
Args:
min_arch:
"""
...
......@@ -14,13 +14,13 @@
import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.utils.boxops import BoxOps
from spconv.core_cc.cumm.common import CompileInfo
CPU_ONLY_BUILD = SpconvOps.is_cpu_only_build()
BUILD_CUMM_VERSION = SpconvOps.cumm_version()
BUILD_PCCM_VERSION = SpconvOps.pccm_version()
from spconv.core_cc.csrc.utils.boxops import BoxOps
from spconv.core_cc.cumm.common import CompileInfo
HAS_BOOST = BoxOps.has_boost()
COMPILED_CUDA_ARCHS = set(CompileInfo.get_compiled_cuda_arch())
......
......@@ -131,6 +131,8 @@ class SpconvOps(pccm.Class):
define_str = "\n".join(defines)
self.add_global_code(define_str)
self.build_meta.add_global_cflags("cl", "/DNOMINMAX")
# self.build_meta.add_global_cflags("nvcc", "-w")
# for name in dir(AllocKeys):
# if not name.startswith("__"):
# v = getattr(AllocKeys, name)
......
......@@ -550,7 +550,6 @@ class GemmTunerSimple(pccm.ParameterizedClass):
}}
// auto avail_algos = get_available_algo_str_from_arch(arch);
std::vector<tv::gemm::GemmAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled_gemm(arch);
static_key_t static_key = std::make_tuple(trans_a, trans_b, trans_c, int(a.dtype()),
int(b.dtype()), int(c.dtype()), shuffle_type);
if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{
......@@ -574,13 +573,22 @@ class GemmTunerSimple(pccm.ParameterizedClass):
auto ldb = b.stride(0);
auto ldc = c.stride(0);
if (desp.supported_ldx(lda, ldb, ldc)){{
if (!is_arch_compiled){{
auto desp2 = desp;
desp2.is_nvrtc = true;
finally_algos.push_back(desp2);
}}else{{
finally_algos.push_back(desp);
if (desp.is_nvrtc){{
if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
continue;
}}
}}
if (!CompileInfo::arch_is_compiled_gemm(arch)){{
if (!CompileInfo::gemm_algo_can_use_ptx(desp.min_arch, arch)){{
if (CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
auto desp2 = desp;
desp2.is_nvrtc = true;
}}else{{
continue;
}}
}}
}}
finally_algos.push_back(desp);
}}
}}
return finally_algos;
......@@ -699,7 +707,7 @@ class GemmTunerSimple(pccm.ParameterizedClass):
for (auto& desp : avail){{
tv::gemm::GemmParams params;
if (desp.is_nvrtc && prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
if (desp.is_nvrtc || prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
params.a = a;
......@@ -865,7 +873,7 @@ class GemmTunerSimple(pccm.ParameterizedClass):
tv::gemm::GemmParams params;
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
if (force_nvrtc || (desp.is_nvrtc && desp_is_static)){{
if (force_nvrtc || (desp.is_nvrtc || desp_is_static)){{
params.nvrtc_params = cached_get_nvrtc_params(desp, profile_res.arch, stream_int);
}}
params.a = a;
......@@ -1008,7 +1016,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
use_f32_as_accum = false;
std::vector<tv::gemm::ConvAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled_gemm(arch);
static_key_t static_key = std::make_tuple(
layout_i, layout_w, layout_o,
interleave_i, interleave_w, interleave_o, inp.dtype(),
......@@ -1053,13 +1060,22 @@ class ConvTunerSimple(pccm.ParameterizedClass):
mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}}
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!is_arch_compiled){{
auto desp2 = desp;
desp2.is_nvrtc = true;
finally_algos.push_back(desp2);
}}else{{
finally_algos.push_back(desp);
if (desp.is_nvrtc){{
if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
continue;
}}
}}
if (!CompileInfo::arch_is_compiled_gemm(arch)){{
if (!CompileInfo::gemm_algo_can_use_ptx(desp.min_arch, arch)){{
if (CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
auto desp2 = desp;
desp2.is_nvrtc = true;
}}else{{
continue;
}}
}}
}}
finally_algos.push_back(desp);
}}
}}
return finally_algos;
......@@ -1134,7 +1150,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
for (auto& desp : avail){{
tv::gemm::ConvParams params({NDIM_DONT_CARE}, op_type_cpp, tv::CUDAKernelTimer(false));
if (desp.is_nvrtc && prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
if (desp.is_nvrtc || prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
params.conv_algo_desp = desp;
......@@ -1311,7 +1327,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
auto arch = profile_res.arch;
tv::gemm::ConvParams params({NDIM_DONT_CARE}, op_type_cpp, timer);
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
if (force_nvrtc || (desp.is_nvrtc && desp_is_static)){{
if (force_nvrtc || (desp.is_nvrtc || desp_is_static)){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
params.conv_algo_desp = desp;
......
......@@ -20,6 +20,8 @@ from spconv.cppconstants import COMPILED_CUDA_ARCHS
import sys
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.core_cc.csrc.sparse.convops import ExternalSpconvMatmul
from spconv.core_cc.cumm.common import CompileInfo
import warnings
import numpy as np
......@@ -93,12 +95,11 @@ def get_current_stream():
def get_arch():
arch = torch.cuda.get_device_capability()
if arch not in COMPILED_CUDA_ARCHS:
print(
if not CompileInfo.arch_is_compatible(arch) and not CompileInfo.algo_can_use_ptx((0, 0), arch):
warnings.warn(
f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, "
f"may cause invalid device function. "
f"available: {COMPILED_CUDA_ARCHS}",
file=sys.stderr)
f"may cause invalid device function error. "
f"available: {COMPILED_CUDA_ARCHS}")
return arch
......
import spconv
from spconv.core_cc.cumm.common import CompileInfo
if __name__ == "__main__":
    # For each GPU architecture, print the results of
    # CompileInfo.arch_is_compatible_gemm and CompileInfo.arch_is_compiled_gemm
    # on one line, in that order.
    for gpu_arch in ((9, 0), (8, 6)):
        is_compatible = CompileInfo.arch_is_compatible_gemm(gpu_arch)
        is_compiled = CompileInfo.arch_is_compiled_gemm(gpu_arch)
        print(is_compatible, is_compiled)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment