Commit 6292da93 authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

[major] fix cutlass on msvc

parent a9f1b7af
...@@ -8,8 +8,11 @@ class CustomBuildExtension(BuildExtension): ...@@ -8,8 +8,11 @@ class CustomBuildExtension(BuildExtension):
for ext in self.extensions: for ext in self.extensions:
if not "cxx" in ext.extra_compile_args: if not "cxx" in ext.extra_compile_args:
ext.extra_compile_args["cxx"] = [] ext.extra_compile_args["cxx"] = []
if not "nvcc" in ext.extra_compile_args:
ext.extra_compile_args["nvcc"] = []
if self.compiler.compiler_type == "msvc": if self.compiler.compiler_type == "msvc":
ext.extra_compile_args["cxx"] += ext.extra_compile_args["msvc"] ext.extra_compile_args["cxx"] += ext.extra_compile_args["msvc"]
ext.extra_compile_args["nvcc"] += ext.extra_compile_args["nvcc_msvc"]
else: else:
ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"] ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
super().build_extensions() super().build_extensions()
...@@ -45,7 +48,7 @@ if __name__ == "__main__": ...@@ -45,7 +48,7 @@ if __name__ == "__main__":
return [] return []
GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og"] GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og"]
MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG"] MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG", "/Zc:__cplusplus"]
NVCC_FLAGS = [ NVCC_FLAGS = [
"-DENABLE_BF16=1", "-DENABLE_BF16=1",
"-DBUILD_NUNCHAKU=1", "-DBUILD_NUNCHAKU=1",
...@@ -71,6 +74,10 @@ if __name__ == "__main__": ...@@ -71,6 +74,10 @@ if __name__ == "__main__":
"--generate-line-info", "--generate-line-info",
"--ptxas-options=--allow-expensive-optimizations=true", "--ptxas-options=--allow-expensive-optimizations=true",
] ]
# https://github.com/NVIDIA/cutlass/pull/1479#issuecomment-2052300487
NVCC_MSVC_FLAGS = [
"-Xcompiler", "/Zc:__cplusplus"
]
nunchaku_extension = CUDAExtension( nunchaku_extension = CUDAExtension(
name="nunchaku._C", name="nunchaku._C",
...@@ -100,7 +107,7 @@ if __name__ == "__main__": ...@@ -100,7 +107,7 @@ if __name__ == "__main__":
*ncond("src/kernels/flash_attn/flash_api.cpp"), *ncond("src/kernels/flash_attn/flash_api.cpp"),
*ncond("src/kernels/flash_attn/flash_api_adapter.cpp"), *ncond("src/kernels/flash_attn/flash_api_adapter.cpp"),
], ],
extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS, "nvcc_msvc": NVCC_MSVC_FLAGS},
include_dirs=INCLUDE_DIRS, include_dirs=INCLUDE_DIRS,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment