Unverified commit e6398f88, authored by Casper and committed by GitHub
Browse files

Fix windows build + Bump version (#3)

* Attempt to fix Windows build

* Revert build

* Bump to 0.0.2
parent fc700a82
......@@ -7,7 +7,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
os.environ["CC"] = "g++"
os.environ["CXX"] = "g++"
AUTOAWQ_KERNELS_VERSION = "0.0.1"
AUTOAWQ_KERNELS_VERSION = "0.0.2"
PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
if not PYPI_BUILD:
......@@ -105,6 +105,7 @@ def get_compute_capabilities():
check_dependencies()
extra_link_args = []
include_dirs = get_include_dirs()
generator_flags = get_generator_flag()
arch_flags = get_compute_capabilities()
......@@ -117,6 +118,9 @@ if os.name == "nt":
extra_compile_args = {"nvcc": arch_flags}
else:
extra_compile_args = {}
cuda_path = os.environ.get("CUDA_PATH", None)
extra_link_args = ["-L", f"{cuda_path}/lib/x64/cublas.lib"]
else:
extra_compile_args = {
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
......@@ -151,6 +155,7 @@ extensions = [
extra_compile_args=extra_compile_args,
)
]
extensions.append(
CUDAExtension(
"exllama_kernels",
......@@ -162,6 +167,7 @@ extensions.append(
"awq_ext/exllama/cuda_func/q4_matrix.cu",
],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
)
extensions.append(
......@@ -173,6 +179,7 @@ extensions.append(
"awq_ext/exllamav2/cuda/q_gemm.cu",
],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment