Commit fa0c0636 authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

[major] detect msvc in setup.py

parent 0c1c2d4a
......@@ -3,6 +3,18 @@ import os
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
class CustomBuildExtension(BuildExtension):
def build_extensions(self):
for ext in self.extensions:
print(ext.extra_compile_args)
if not "cxx" in ext.extra_compile_args:
ext.extra_compile_args["cxx"] = []
if self.compiler.compiler_type == "msvc":
ext.extra_compile_args["cxx"] += ext.extra_compile_args["msvc"]
else:
ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
super().build_extensions()
if __name__ == "__main__":
fp = open("nunchaku/__version__.py", "r").read()
version = eval(fp.strip().split()[-1])
......@@ -17,7 +29,7 @@ if __name__ == "__main__":
"third_party/spdlog/include",
]
INCLUDE_DIRS = ["-I" + ROOT_DIR + "/" + dir for dir in INCLUDE_DIRS]
INCLUDE_DIRS = [ROOT_DIR + "/" + dir for dir in INCLUDE_DIRS]
DEBUG = False
......@@ -33,8 +45,10 @@ if __name__ == "__main__":
else:
return []
CXX_FLAGS = ["-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og", *INCLUDE_DIRS]
GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og"]
MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG"]
NVCC_FLAGS = [
"-DENABLE_BF16=1",
"-DBUILD_NUNCHAKU=1",
"-gencode", "arch=compute_86,code=sm_86",
"-gencode", "arch=compute_89,code=sm_89",
......@@ -57,7 +71,6 @@ if __name__ == "__main__":
"--expt-extended-lambda",
"--generate-line-info",
"--ptxas-options=--allow-expensive-optimizations=true",
*INCLUDE_DIRS,
]
nunchaku_extension = CUDAExtension(
......@@ -88,7 +101,8 @@ if __name__ == "__main__":
*ncond("src/kernels/flash_attn/flash_api.cpp"),
*ncond("src/kernels/flash_attn/flash_api_adapter.cpp"),
],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS},
include_dirs=INCLUDE_DIRS,
)
setuptools.setup(
......@@ -96,5 +110,5 @@ if __name__ == "__main__":
version=version,
packages=setuptools.find_packages(),
ext_modules=[nunchaku_extension],
cmdclass={"build_ext": BuildExtension},
cmdclass={"build_ext": CustomBuildExtension},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment