You need to sign in or sign up before continuing.
Commit bf34f040 authored by yan.yan's avatar yan.yan
Browse files

fix build and nvrtc problem

parent 8c25ed52
# Changelog # Changelog
## [2.2.1] - 2022-9-25
### Fixed
- Fix build problem
- Fix nvrtc problem
## [2.2.0] - 2022-9-24 ## [2.2.0] - 2022-9-24
### Added ### Added
- Add Ampere support. faster fp16, faster tf32 and greatly faster int8 kernels in Ampere GPUs. - Add Ampere support. faster fp16, faster tf32 and greatly faster int8 kernels in Ampere GPUs.
......
[build-system] [build-system]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.1"] requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.2"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
...@@ -38,9 +38,9 @@ if cuda_ver: ...@@ -38,9 +38,9 @@ if cuda_ver:
cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102 cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102
RELEASE_NAME += "-cu{}".format(cuda_ver) RELEASE_NAME += "-cu{}".format(cuda_ver)
deps = ["cumm-cu{}>=0.3.1".format(cuda_ver)] deps = ["cumm-cu{}>=0.3.2".format(cuda_ver)]
else: else:
deps = ["cumm>=0.3.1"] deps = ["cumm>=0.3.2"]
......
...@@ -17,7 +17,7 @@ import time ...@@ -17,7 +17,7 @@ import time
from enum import Enum from enum import Enum
from threading import Lock from threading import Lock
from typing import Dict, List, Optional, Set, Tuple, Union from typing import Dict, List, Optional, Set, Tuple, Union
from spconv.core_cc.cumm.common import CompileInfo
import numpy as np import numpy as np
from cumm import tensorview as tv from cumm import tensorview as tv
from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType from cumm.conv.bases import ConvLayout, ConvLayoutType, ConvOpType
...@@ -337,9 +337,20 @@ class SimpleGemm: ...@@ -337,9 +337,20 @@ class SimpleGemm:
ldb = b.stride[0] ldb = b.stride[0]
ldc = c.stride[0] ldc = c.stride[0]
if desp.supported_ldx(lda, ldb, ldc): if desp.supported_ldx(lda, ldb, ldc):
if arch not in COMPILED_CUDA_GEMM_ARCHS: if desp.is_nvrtc:
desp = desp.copy() if not CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
desp.is_nvrtc = True continue
if not CompileInfo.arch_is_compiled_gemm(arch):
# use PTX of possible
if not CompileInfo.gemm_algo_can_use_ptx(desp.min_arch, arch):
if CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
# compiled kernel can't use PTX, for example, desp need at least sm_80 and only sm_75+PTX is compiled
# all sm_80 code of this desp is invalid, we must use nvrtc.
# only desp <= sm_75 can use virtual PTX code to generate sm_80 code.
desp = desp.copy()
desp.is_nvrtc = True
else:
continue
if SPCONV_DEBUG_NVRTC_KERNELS: if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True desp.is_nvrtc = True
finally_algos.append(desp) finally_algos.append(desp)
...@@ -455,7 +466,7 @@ class SimpleGemm: ...@@ -455,7 +466,7 @@ class SimpleGemm:
if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value: if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
split_k_slices = max(min(32, k // 128), 1) split_k_slices = max(min(32, k // 128), 1)
params = GemmParams() params = GemmParams()
if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names: if desp.is_nvrtc or str(desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(desp, arch) params.nvrtc_params = self._cached_get_nvrtc_params(desp, arch)
params.a = a params.a = a
params.b = b params.b = b
...@@ -550,7 +561,7 @@ class SimpleGemm: ...@@ -550,7 +561,7 @@ class SimpleGemm:
split_k_slices = profile_res.splitk split_k_slices = profile_res.splitk
params = GemmParams() params = GemmParams()
is_not_static = str(algo_desp) not in self.prebuilt_desp_names is_not_static = str(algo_desp) not in self.prebuilt_desp_names
if algo_desp.is_nvrtc and (is_not_static or force_nvrtc): if algo_desp.is_nvrtc or is_not_static or force_nvrtc:
params.nvrtc_params = self._cached_get_nvrtc_params( params.nvrtc_params = self._cached_get_nvrtc_params(
algo_desp, profile_res.arch) algo_desp, profile_res.arch)
...@@ -720,9 +731,20 @@ class SimpleConv: ...@@ -720,9 +731,20 @@ class SimpleConv:
assert mask_width > 0 assert mask_width > 0
mask_width_valid = mask_width % desp.tile_shape[2] == 0 mask_width_valid = mask_width % desp.tile_shape[2] == 0
if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid: if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
if arch not in COMPILED_CUDA_GEMM_ARCHS: if desp.is_nvrtc:
desp = desp.copy() if not CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
desp.is_nvrtc = True continue
if not CompileInfo.arch_is_compiled_gemm(arch):
# use PTX of possible
if not CompileInfo.gemm_algo_can_use_ptx(desp.min_arch, arch):
if CompileInfo.algo_can_be_nvrtc_compiled(desp.min_arch):
# compiled kernel can't use PTX, for example, desp need at least sm_80 and only sm_75+PTX is compiled
# all sm_80 code of this desp is invalid, we must use nvrtc.
# only desp <= sm_75 can use virtual PTX code to generate sm_80 code.
desp = desp.copy()
desp.is_nvrtc = True
else:
continue
if SPCONV_DEBUG_NVRTC_KERNELS: if SPCONV_DEBUG_NVRTC_KERNELS:
desp.is_nvrtc = True desp.is_nvrtc = True
finally_algos.append(desp) finally_algos.append(desp)
...@@ -826,7 +848,7 @@ class SimpleConv: ...@@ -826,7 +848,7 @@ class SimpleConv:
for desp in avail: for desp in avail:
# for sparse conv, ndim isn't used, so we just provide a constant value. # for sparse conv, ndim isn't used, so we just provide a constant value.
params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type.value)) params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type.value))
if desp.is_nvrtc and str(desp) not in self.prebuilt_desp_names: if desp.is_nvrtc or str(desp) not in self.prebuilt_desp_names:
params.nvrtc_params = self._cached_get_nvrtc_params(desp, arch) params.nvrtc_params = self._cached_get_nvrtc_params(desp, arch)
params.conv_algo_desp = desp params.conv_algo_desp = desp
...@@ -935,7 +957,7 @@ class SimpleConv: ...@@ -935,7 +957,7 @@ class SimpleConv:
params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type_value)) params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type_value))
is_not_static = str( is_not_static = str(
algo_desp) not in self.prebuilt_desp_names algo_desp) not in self.prebuilt_desp_names
if force_nvrtc or (algo_desp.is_nvrtc and is_not_static): if force_nvrtc or algo_desp.is_nvrtc or is_not_static:
params.nvrtc_params = self._cached_get_nvrtc_params( params.nvrtc_params = self._cached_get_nvrtc_params(
algo_desp, profile_res.arch) algo_desp, profile_res.arch)
params.conv_algo_desp = profile_res.algo_desp params.conv_algo_desp = profile_res.algo_desp
......
...@@ -3,7 +3,6 @@ from pccm.stubs import EnumValue, EnumClassValue ...@@ -3,7 +3,6 @@ from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import GemmAlgoDesp from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview.gemm import ConvAlgoDesp from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview import Tensor from cumm.tensorview import Tensor
from ...csrc.sparse.convops import ExternalSpconvMatmul
class GemmTuneResult: class GemmTuneResult:
algo_desp: GemmAlgoDesp algo_desp: GemmAlgoDesp
arch: Tuple[int, int] arch: Tuple[int, int]
......
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue from pccm.stubs import EnumValue, EnumClassValue
class CompileInfo: class CompileInfo:
@staticmethod
def get_compiled_cuda_version() -> Tuple[int, int]: ...
@staticmethod @staticmethod
def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ... def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ...
@staticmethod @staticmethod
...@@ -19,3 +21,40 @@ class CompileInfo: ...@@ -19,3 +21,40 @@ class CompileInfo:
arch: arch:
""" """
... ...
@staticmethod
def arch_is_compatible(arch: Tuple[int, int]) -> bool:
"""
Args:
arch:
"""
...
@staticmethod
def arch_is_compatible_gemm(arch: Tuple[int, int]) -> bool:
"""
Args:
arch:
"""
...
@staticmethod
def algo_can_use_ptx(min_arch: Tuple[int, int], arch: Tuple[int, int]) -> bool:
"""
Args:
min_arch:
arch:
"""
...
@staticmethod
def gemm_algo_can_use_ptx(min_arch: Tuple[int, int], arch: Tuple[int, int]) -> bool:
"""
Args:
min_arch:
arch:
"""
...
@staticmethod
def algo_can_be_nvrtc_compiled(min_arch: Tuple[int, int]) -> bool:
"""
Args:
min_arch:
"""
...
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
import spconv.core_cc as _ext import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.utils.boxops import BoxOps
from spconv.core_cc.cumm.common import CompileInfo
CPU_ONLY_BUILD = SpconvOps.is_cpu_only_build() CPU_ONLY_BUILD = SpconvOps.is_cpu_only_build()
BUILD_CUMM_VERSION = SpconvOps.cumm_version() BUILD_CUMM_VERSION = SpconvOps.cumm_version()
BUILD_PCCM_VERSION = SpconvOps.pccm_version() BUILD_PCCM_VERSION = SpconvOps.pccm_version()
from spconv.core_cc.csrc.utils.boxops import BoxOps
from spconv.core_cc.cumm.common import CompileInfo
HAS_BOOST = BoxOps.has_boost() HAS_BOOST = BoxOps.has_boost()
COMPILED_CUDA_ARCHS = set(CompileInfo.get_compiled_cuda_arch()) COMPILED_CUDA_ARCHS = set(CompileInfo.get_compiled_cuda_arch())
......
...@@ -131,6 +131,8 @@ class SpconvOps(pccm.Class): ...@@ -131,6 +131,8 @@ class SpconvOps(pccm.Class):
define_str = "\n".join(defines) define_str = "\n".join(defines)
self.add_global_code(define_str) self.add_global_code(define_str)
self.build_meta.add_global_cflags("cl", "/DNOMINMAX") self.build_meta.add_global_cflags("cl", "/DNOMINMAX")
# self.build_meta.add_global_cflags("nvcc", "-w")
# for name in dir(AllocKeys): # for name in dir(AllocKeys):
# if not name.startswith("__"): # if not name.startswith("__"):
# v = getattr(AllocKeys, name) # v = getattr(AllocKeys, name)
......
...@@ -550,7 +550,6 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -550,7 +550,6 @@ class GemmTunerSimple(pccm.ParameterizedClass):
}} }}
// auto avail_algos = get_available_algo_str_from_arch(arch); // auto avail_algos = get_available_algo_str_from_arch(arch);
std::vector<tv::gemm::GemmAlgoDesp> finally_algos; std::vector<tv::gemm::GemmAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled_gemm(arch);
static_key_t static_key = std::make_tuple(trans_a, trans_b, trans_c, int(a.dtype()), static_key_t static_key = std::make_tuple(trans_a, trans_b, trans_c, int(a.dtype()),
int(b.dtype()), int(c.dtype()), shuffle_type); int(b.dtype()), int(c.dtype()), shuffle_type);
if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{ if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{
...@@ -574,13 +573,22 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -574,13 +573,22 @@ class GemmTunerSimple(pccm.ParameterizedClass):
auto ldb = b.stride(0); auto ldb = b.stride(0);
auto ldc = c.stride(0); auto ldc = c.stride(0);
if (desp.supported_ldx(lda, ldb, ldc)){{ if (desp.supported_ldx(lda, ldb, ldc)){{
if (!is_arch_compiled){{ if (desp.is_nvrtc){{
auto desp2 = desp; if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
desp2.is_nvrtc = true; continue;
finally_algos.push_back(desp2); }}
}}else{{ }}
finally_algos.push_back(desp); if (!CompileInfo::arch_is_compiled_gemm(arch)){{
if (!CompileInfo::gemm_algo_can_use_ptx(desp.min_arch, arch)){{
if (CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
auto desp2 = desp;
desp2.is_nvrtc = true;
}}else{{
continue;
}}
}}
}} }}
finally_algos.push_back(desp);
}} }}
}} }}
return finally_algos; return finally_algos;
...@@ -699,7 +707,7 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -699,7 +707,7 @@ class GemmTunerSimple(pccm.ParameterizedClass):
for (auto& desp : avail){{ for (auto& desp : avail){{
tv::gemm::GemmParams params; tv::gemm::GemmParams params;
if (desp.is_nvrtc && prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{ if (desp.is_nvrtc || prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int); params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}} }}
params.a = a; params.a = a;
...@@ -865,7 +873,7 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -865,7 +873,7 @@ class GemmTunerSimple(pccm.ParameterizedClass):
tv::gemm::GemmParams params; tv::gemm::GemmParams params;
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end(); bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
if (force_nvrtc || (desp.is_nvrtc && desp_is_static)){{ if (force_nvrtc || (desp.is_nvrtc || desp_is_static)){{
params.nvrtc_params = cached_get_nvrtc_params(desp, profile_res.arch, stream_int); params.nvrtc_params = cached_get_nvrtc_params(desp, profile_res.arch, stream_int);
}} }}
params.a = a; params.a = a;
...@@ -1008,7 +1016,6 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1008,7 +1016,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
use_f32_as_accum = false; use_f32_as_accum = false;
std::vector<tv::gemm::ConvAlgoDesp> finally_algos; std::vector<tv::gemm::ConvAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled_gemm(arch);
static_key_t static_key = std::make_tuple( static_key_t static_key = std::make_tuple(
layout_i, layout_w, layout_o, layout_i, layout_w, layout_o,
interleave_i, interleave_w, interleave_o, inp.dtype(), interleave_i, interleave_w, interleave_o, inp.dtype(),
...@@ -1053,13 +1060,22 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1053,13 +1060,22 @@ class ConvTunerSimple(pccm.ParameterizedClass):
mask_width_valid = mask_width % desp.tile_shape[2] == 0; mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}} }}
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{ if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!is_arch_compiled){{ if (desp.is_nvrtc){{
auto desp2 = desp; if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
desp2.is_nvrtc = true; continue;
finally_algos.push_back(desp2); }}
}}else{{ }}
finally_algos.push_back(desp); if (!CompileInfo::arch_is_compiled_gemm(arch)){{
if (!CompileInfo::gemm_algo_can_use_ptx(desp.min_arch, arch)){{
if (CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
auto desp2 = desp;
desp2.is_nvrtc = true;
}}else{{
continue;
}}
}}
}} }}
finally_algos.push_back(desp);
}} }}
}} }}
return finally_algos; return finally_algos;
...@@ -1134,7 +1150,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1134,7 +1150,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type); tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
for (auto& desp : avail){{ for (auto& desp : avail){{
tv::gemm::ConvParams params({NDIM_DONT_CARE}, op_type_cpp, tv::CUDAKernelTimer(false)); tv::gemm::ConvParams params({NDIM_DONT_CARE}, op_type_cpp, tv::CUDAKernelTimer(false));
if (desp.is_nvrtc && prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{ if (desp.is_nvrtc || prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int); params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}} }}
params.conv_algo_desp = desp; params.conv_algo_desp = desp;
...@@ -1311,7 +1327,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1311,7 +1327,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
auto arch = profile_res.arch; auto arch = profile_res.arch;
tv::gemm::ConvParams params({NDIM_DONT_CARE}, op_type_cpp, timer); tv::gemm::ConvParams params({NDIM_DONT_CARE}, op_type_cpp, timer);
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end(); bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
if (force_nvrtc || (desp.is_nvrtc && desp_is_static)){{ if (force_nvrtc || (desp.is_nvrtc || desp_is_static)){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int); params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}} }}
params.conv_algo_desp = desp; params.conv_algo_desp = desp;
......
...@@ -20,6 +20,8 @@ from spconv.cppconstants import COMPILED_CUDA_ARCHS ...@@ -20,6 +20,8 @@ from spconv.cppconstants import COMPILED_CUDA_ARCHS
import sys import sys
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.core_cc.csrc.sparse.convops import ExternalSpconvMatmul from spconv.core_cc.csrc.sparse.convops import ExternalSpconvMatmul
from spconv.core_cc.cumm.common import CompileInfo
import warnings
import numpy as np import numpy as np
...@@ -93,12 +95,11 @@ def get_current_stream(): ...@@ -93,12 +95,11 @@ def get_current_stream():
def get_arch(): def get_arch():
arch = torch.cuda.get_device_capability() arch = torch.cuda.get_device_capability()
if arch not in COMPILED_CUDA_ARCHS: if not CompileInfo.arch_is_compatible(arch) and not CompileInfo.algo_can_use_ptx((0, 0), arch):
print( warnings.warn(
f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, " f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, "
f"may cause invalid device function. " f"may cause invalid device function error. "
f"available: {COMPILED_CUDA_ARCHS}", f"available: {COMPILED_CUDA_ARCHS}")
file=sys.stderr)
return arch return arch
......
import spconv import spconv
from spconv.core_cc.cumm.common import CompileInfo
if __name__ == "__main__":
print(CompileInfo.arch_is_compatible_gemm((9, 0)), CompileInfo.arch_is_compiled_gemm((9, 0)))
print(CompileInfo.arch_is_compatible_gemm((8, 6)), CompileInfo.arch_is_compiled_gemm((8, 6)))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment