Commit b63c08aa authored by yan.yan's avatar yan.yan
Browse files

fix cuda version bug

parent 77a7981a
......@@ -177,8 +177,13 @@ if disable_jit is not None and disable_jit == "1":
cu.namespace = "cumm.gemm.main"
std = "c++17"
if cuda_ver:
cuda_ver_vec = list(map(int, cuda_ver.split(".")))
cuda_ver_tuple = (cuda_ver_vec[0], cuda_ver_vec[1])
cuda_ver_items = cuda_ver.split(".")
if len(cuda_ver_items) == 1:
cuda_ver_num = int(cuda_ver)
cuda_ver_tuple = (cuda_ver_num // 10, cuda_ver_num % 10)
else:
cuda_ver_vec = list(map(int, cuda_ver.split(".")))
cuda_ver_tuple = (cuda_ver_vec[0], cuda_ver_vec[1])
if cuda_ver_tuple[0] < 11:
std = "c++14"
else:
......
......@@ -134,8 +134,13 @@ class SpconvOps(pccm.Class):
self.build_meta.add_global_cflags("cl", "/DNOMINMAX")
cuda_ver = os.environ.get("CUMM_CUDA_VERSION", "")
if cuda_ver:
cuda_ver_vec = list(map(int, cuda_ver.split(".")))
cuda_ver_tuple = (cuda_ver_vec[0], cuda_ver_vec[1])
cuda_ver_items = cuda_ver.split(".")
if len(cuda_ver_items) == 1:
cuda_ver_num = int(cuda_ver)
cuda_ver_tuple = (cuda_ver_num // 10, cuda_ver_num % 10)
else:
cuda_ver_vec = list(map(int, cuda_ver.split(".")))
cuda_ver_tuple = (cuda_ver_vec[0], cuda_ver_vec[1])
if cuda_ver_tuple[0] < 11:
self.build_meta.add_global_cflags("nvcc", "-w")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment