Commit fa0c0636 authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

[major] detect msvc in setup.py

parent 0c1c2d4a
...@@ -3,6 +3,18 @@ import os ...@@ -3,6 +3,18 @@ import os
import setuptools import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
class CustomBuildExtension(BuildExtension):
def build_extensions(self):
for ext in self.extensions:
print(ext.extra_compile_args)
if not "cxx" in ext.extra_compile_args:
ext.extra_compile_args["cxx"] = []
if self.compiler.compiler_type == "msvc":
ext.extra_compile_args["cxx"] += ext.extra_compile_args["msvc"]
else:
ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
super().build_extensions()
if __name__ == "__main__": if __name__ == "__main__":
fp = open("nunchaku/__version__.py", "r").read() fp = open("nunchaku/__version__.py", "r").read()
version = eval(fp.strip().split()[-1]) version = eval(fp.strip().split()[-1])
...@@ -17,7 +29,7 @@ if __name__ == "__main__": ...@@ -17,7 +29,7 @@ if __name__ == "__main__":
"third_party/spdlog/include", "third_party/spdlog/include",
] ]
INCLUDE_DIRS = ["-I" + ROOT_DIR + "/" + dir for dir in INCLUDE_DIRS] INCLUDE_DIRS = [ROOT_DIR + "/" + dir for dir in INCLUDE_DIRS]
DEBUG = False DEBUG = False
...@@ -33,8 +45,10 @@ if __name__ == "__main__": ...@@ -33,8 +45,10 @@ if __name__ == "__main__":
else: else:
return [] return []
CXX_FLAGS = ["-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og", *INCLUDE_DIRS] GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og"]
MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG"]
NVCC_FLAGS = [ NVCC_FLAGS = [
"-DENABLE_BF16=1",
"-DBUILD_NUNCHAKU=1", "-DBUILD_NUNCHAKU=1",
"-gencode", "arch=compute_86,code=sm_86", "-gencode", "arch=compute_86,code=sm_86",
"-gencode", "arch=compute_89,code=sm_89", "-gencode", "arch=compute_89,code=sm_89",
...@@ -57,7 +71,6 @@ if __name__ == "__main__": ...@@ -57,7 +71,6 @@ if __name__ == "__main__":
"--expt-extended-lambda", "--expt-extended-lambda",
"--generate-line-info", "--generate-line-info",
"--ptxas-options=--allow-expensive-optimizations=true", "--ptxas-options=--allow-expensive-optimizations=true",
*INCLUDE_DIRS,
] ]
nunchaku_extension = CUDAExtension( nunchaku_extension = CUDAExtension(
...@@ -88,7 +101,8 @@ if __name__ == "__main__": ...@@ -88,7 +101,8 @@ if __name__ == "__main__":
*ncond("src/kernels/flash_attn/flash_api.cpp"), *ncond("src/kernels/flash_attn/flash_api.cpp"),
*ncond("src/kernels/flash_attn/flash_api_adapter.cpp"), *ncond("src/kernels/flash_attn/flash_api_adapter.cpp"),
], ],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS},
include_dirs=INCLUDE_DIRS,
) )
setuptools.setup( setuptools.setup(
...@@ -96,5 +110,5 @@ if __name__ == "__main__": ...@@ -96,5 +110,5 @@ if __name__ == "__main__":
version=version, version=version,
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
ext_modules=[nunchaku_extension], ext_modules=[nunchaku_extension],
cmdclass={"build_ext": BuildExtension}, cmdclass={"build_ext": CustomBuildExtension},
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment