Commit fe03ed93 authored by yan.yan's avatar yan.yan
Browse files

use c++17 for cuda >= 11 to resolve msvc bug

parent f83bba37
...@@ -17,16 +17,25 @@ from pccm.extension import ExtCallback, PCCMBuild, PCCMExtension ...@@ -17,16 +17,25 @@ from pccm.extension import ExtCallback, PCCMBuild, PCCMExtension
from setuptools import Command, find_packages, setup from setuptools import Command, find_packages, setup
from setuptools.extension import Extension from setuptools.extension import Extension
from ccimport import compat from ccimport import compat
import subprocess
import re
# Package meta-data. # Package meta-data.
NAME = 'spconv' NAME = 'spconv'
RELEASE_NAME = NAME RELEASE_NAME = NAME
deps = ["cumm"] deps = ["cumm"]
cuda_ver = os.environ.get("CUMM_CUDA_VERSON", "") cuda_ver = os.environ.get("CUMM_CUDA_VERSON", "")
if cuda_ver: if not cuda_ver:
cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102 nvcc_version = subprocess.check_output(["nvcc", "--version"
RELEASE_NAME += "-cu{}".format(cuda_ver) ]).decode("utf-8").strip()
deps = ["cumm-cu{}".format(cuda_ver)] nvcc_version_str = nvcc_version.split("\n")[3]
version_str: str = re.findall(r"release (\d+.\d+)",
nvcc_version_str)[0]
cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102
RELEASE_NAME += "-cu{}".format(cuda_ver)
deps = ["cumm-cu{}".format(cuda_ver)]
DESCRIPTION = 'spatial sparse convolution' DESCRIPTION = 'spatial sparse convolution'
URL = 'https://github.com/traveller59/spconv' URL = 'https://github.com/traveller59/spconv'
EMAIL = 'yanyan.sub@outlook.com' EMAIL = 'yanyan.sub@outlook.com'
...@@ -138,10 +147,11 @@ if disable_jit is not None and disable_jit == "1": ...@@ -138,10 +147,11 @@ if disable_jit is not None and disable_jit == "1":
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS) cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS)
cu.namespace = "cumm.gemm.main" cu.namespace = "cumm.gemm.main"
if compat.InWindows: cuda_ver_number = int(cuda_ver)
std = None if cuda_ver_number < 110:
std = "c++14"
else: else:
std = "-std=c++14" std = "c++17"
ext_modules: List[Extension] = [ ext_modules: List[Extension] = [
PCCMExtension([cu, SpconvOps()], PCCMExtension([cu, SpconvOps()],
"spconv/core_cc", "spconv/core_cc",
......
...@@ -26,15 +26,6 @@ class SpconvOps(pccm.Class): ...@@ -26,15 +26,6 @@ class SpconvOps(pccm.Class):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.ndims = [1, 2, 3, 4] self.ndims = [1, 2, 3, 4]
if compat.InWindows:
if "cl" not in self.build_meta.compiler_to_cflags:
self.build_meta.compiler_to_cflags["cl"] = []
self.build_meta.compiler_to_cflags["cl"].extend("-Xcompiler=\"/std:c++17\"")
if "nvcc" not in self.build_meta.compiler_to_cflags:
self.build_meta.compiler_to_cflags["nvcc"] = []
self.build_meta.compiler_to_cflags["nvcc"].extend("-std=c++14")
for ndim in self.ndims: for ndim in self.ndims:
p2v = Point2Voxel(dtypes.float32, ndim) p2v = Point2Voxel(dtypes.float32, ndim)
p2v_cpu = Point2VoxelCPU(dtypes.float32, ndim) p2v_cpu = Point2VoxelCPU(dtypes.float32, ndim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment