Commit 6292da93 authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

[major] fix cutlass on msvc

parent a9f1b7af
......@@ -8,8 +8,11 @@ class CustomBuildExtension(BuildExtension):
for ext in self.extensions:
if not "cxx" in ext.extra_compile_args:
ext.extra_compile_args["cxx"] = []
if not "nvcc" in ext.extra_compile_args:
ext.extra_compile_args["nvcc"] = []
if self.compiler.compiler_type == "msvc":
ext.extra_compile_args["cxx"] += ext.extra_compile_args["msvc"]
ext.extra_compile_args["nvcc"] += ext.extra_compile_args["nvcc_msvc"]
else:
ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
super().build_extensions()
......@@ -45,7 +48,7 @@ if __name__ == "__main__":
return []
GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og"]
MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG"]
MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG", "/Zc:__cplusplus"]
NVCC_FLAGS = [
"-DENABLE_BF16=1",
"-DBUILD_NUNCHAKU=1",
......@@ -71,6 +74,10 @@ if __name__ == "__main__":
"--generate-line-info",
"--ptxas-options=--allow-expensive-optimizations=true",
]
# https://github.com/NVIDIA/cutlass/pull/1479#issuecomment-2052300487
NVCC_MSVC_FLAGS = [
"-Xcompiler", "/Zc:__cplusplus"
]
nunchaku_extension = CUDAExtension(
name="nunchaku._C",
......@@ -100,7 +107,7 @@ if __name__ == "__main__":
*ncond("src/kernels/flash_attn/flash_api.cpp"),
*ncond("src/kernels/flash_attn/flash_api_adapter.cpp"),
],
extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS},
extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS, "nvcc_msvc": NVCC_MSVC_FLAGS},
include_dirs=INCLUDE_DIRS,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment