Commit 1a3acf02 authored by Casper's avatar Casper
Browse files

Generalize to Linux and Windows

parent 00ec82bc
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from pathlib import Path from pathlib import Path
from setuptools import setup, find_packages from setuptools import setup, find_packages
from distutils.sysconfig import get_python_lib from distutils.sysconfig import get_python_lib
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
os.environ["CC"] = "g++" os.environ["CC"] = "g++"
os.environ["CXX"] = "g++" os.environ["CXX"] = "g++"
...@@ -19,7 +19,7 @@ common_setup_kwargs = { ...@@ -19,7 +19,7 @@ common_setup_kwargs = {
"long_description_content_type": "text/markdown", "long_description_content_type": "text/markdown",
"url": "https://github.com/casper-hansen/AutoAWQ", "url": "https://github.com/casper-hansen/AutoAWQ",
"keywords": ["awq", "autoawq", "quantization", "transformers"], "keywords": ["awq", "autoawq", "quantization", "transformers"],
"platforms": ["windows", "linux"], "platforms": ["linux", "windows"],
"classifiers": [ "classifiers": [
"Environment :: GPU :: NVIDIA CUDA :: 11.8", "Environment :: GPU :: NVIDIA CUDA :: 11.8",
"Environment :: GPU :: NVIDIA CUDA :: 12", "Environment :: GPU :: NVIDIA CUDA :: 12",
...@@ -79,6 +79,17 @@ def get_compute_capabilities(): ...@@ -79,6 +79,17 @@ def get_compute_capabilities():
check_dependencies() check_dependencies()
arch_flags = get_compute_capabilities() arch_flags = get_compute_capabilities()
if os.name == "nt":
# Relaxed args on Windows
extra_compile_args={
"nvcc": arch_flags
}
else:
extra_compile_args={
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
"nvcc": ["-O3", "-std=c++17"] + arch_flags
}
extensions = [ extensions = [
CUDAExtension( CUDAExtension(
"awq_inference_engine", "awq_inference_engine",
...@@ -87,7 +98,7 @@ extensions = [ ...@@ -87,7 +98,7 @@ extensions = [
"awq_cuda/quantization/gemm_cuda_gen.cu", "awq_cuda/quantization/gemm_cuda_gen.cu",
"awq_cuda/layernorm/layernorm.cu", "awq_cuda/layernorm/layernorm.cu",
"awq_cuda/position_embedding/pos_encoding_kernels.cu" "awq_cuda/position_embedding/pos_encoding_kernels.cu"
] ], extra_compile_args=extra_compile_args
) )
] ]
...@@ -103,4 +114,4 @@ setup( ...@@ -103,4 +114,4 @@ setup(
install_requires=requirements, install_requires=requirements,
include_dirs=include_dirs, include_dirs=include_dirs,
**common_setup_kwargs **common_setup_kwargs
) )
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment