Commit fe03ed93 authored by yan.yan's avatar yan.yan
Browse files

use c++17 for cuda >= 11 to resolve msvc bug

parent f83bba37
......@@ -17,16 +17,25 @@ from pccm.extension import ExtCallback, PCCMBuild, PCCMExtension
from setuptools import Command, find_packages, setup
from setuptools.extension import Extension
from ccimport import compat
import subprocess
import re
# Package meta-data.
NAME = 'spconv'
RELEASE_NAME = NAME
deps = ["cumm"]
cuda_ver = os.environ.get("CUMM_CUDA_VERSON", "")
if cuda_ver:
cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102
RELEASE_NAME += "-cu{}".format(cuda_ver)
deps = ["cumm-cu{}".format(cuda_ver)]
if not cuda_ver:
nvcc_version = subprocess.check_output(["nvcc", "--version"
]).decode("utf-8").strip()
nvcc_version_str = nvcc_version.split("\n")[3]
version_str: str = re.findall(r"release (\d+.\d+)",
nvcc_version_str)[0]
cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102
RELEASE_NAME += "-cu{}".format(cuda_ver)
deps = ["cumm-cu{}".format(cuda_ver)]
DESCRIPTION = 'spatial sparse convolution'
URL = 'https://github.com/traveller59/spconv'
EMAIL = 'yanyan.sub@outlook.com'
......@@ -138,10 +147,11 @@ if disable_jit is not None and disable_jit == "1":
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS)
cu.namespace = "cumm.gemm.main"
if compat.InWindows:
std = None
cuda_ver_number = int(cuda_ver)
if cuda_ver_number < 110:
std = "c++14"
else:
std = "-std=c++14"
std = "c++17"
ext_modules: List[Extension] = [
PCCMExtension([cu, SpconvOps()],
"spconv/core_cc",
......
......@@ -26,15 +26,6 @@ class SpconvOps(pccm.Class):
def __init__(self):
super().__init__()
self.ndims = [1, 2, 3, 4]
if compat.InWindows:
if "cl" not in self.build_meta.compiler_to_cflags:
self.build_meta.compiler_to_cflags["cl"] = []
self.build_meta.compiler_to_cflags["cl"].extend("-Xcompiler=\"/std:c++17\"")
if "nvcc" not in self.build_meta.compiler_to_cflags:
self.build_meta.compiler_to_cflags["nvcc"] = []
self.build_meta.compiler_to_cflags["nvcc"].extend("-std=c++14")
for ndim in self.ndims:
p2v = Point2Voxel(dtypes.float32, ndim)
p2v_cpu = Point2VoxelCPU(dtypes.float32, ndim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment