Commit 83c2f09d authored by Casper's avatar Casper
Browse files

Add compute capability

parent bfa5ba70
import os import os
import torch
from pathlib import Path from pathlib import Path
from torch.utils import cpp_extension
from setuptools import setup, find_packages from setuptools import setup, find_packages
from distutils.sysconfig import get_python_lib from distutils.sysconfig import get_python_lib
from torch.utils import cpp_extension, CUDA_HOME
os.environ["CC"] = "g++" os.environ["CC"] = "g++"
os.environ["CXX"] = "g++" os.environ["CXX"] = "g++"
...@@ -52,6 +53,32 @@ conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/inc ...@@ -52,6 +53,32 @@ conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/inc
if os.path.isdir(conda_cuda_include_dir): if os.path.isdir(conda_cuda_include_dir):
include_dirs.append(conda_cuda_include_dir) include_dirs.append(conda_cuda_include_dir)
def check_dependencies():
    """Abort the build early when no CUDA toolkit location is known.

    Raises:
        RuntimeError: If ``CUDA_HOME`` is ``None``.
    """
    cuda_home_known = CUDA_HOME is not None
    if not cuda_home_known:
        message = "Cannot find CUDA_HOME. CUDA must be available to build the package."
        raise RuntimeError(message)
def get_compute_capabilities():
    """Return the nvcc ``-gencode`` flags for every targeted GPU architecture.

    Every visible CUDA device is validated first: the build is aborted when
    any of them reports a compute capability below 8.0. The returned flags
    always cover the fixed architecture list below, independent of which GPUs
    are installed on the build machine.

    Returns:
        list[str]: Alternating ``"-gencode"`` and
        ``"arch=compute_XX,code=sm_XX"`` entries, in ascending architecture
        order.

    Raises:
        RuntimeError: If a visible GPU has compute capability below 8.0.
    """
    # Validation only: the detected capabilities are not used to choose the
    # build targets, so there is no need to collect them.
    for device_index in range(torch.cuda.device_count()):
        major, _minor = torch.cuda.get_device_capability(device_index)
        if major < 8:
            raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")

    # Fixed target list. An ordered tuple (rather than a set) keeps the
    # generated nvcc command line identical on every run and interpreter.
    target_capabilities = (80, 86, 89, 90)

    capability_flags = []
    for cap in target_capabilities:
        capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
    return capability_flags
# Fail fast when CUDA_HOME is None, then build the nvcc -gencode flags that
# are appended to the extension's "nvcc" compile arguments below.
check_dependencies()
arch_flags = get_compute_capabilities()
extensions = [ extensions = [
cpp_extension.CppExtension( cpp_extension.CppExtension(
"awq_inference_engine", "awq_inference_engine",
...@@ -62,7 +89,7 @@ extensions = [ ...@@ -62,7 +89,7 @@ extensions = [
"awq_cuda/position_embedding/pos_encoding_kernels.cu" "awq_cuda/position_embedding/pos_encoding_kernels.cu"
], extra_compile_args={ ], extra_compile_args={
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"], "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
"nvcc": ["-O3", "-std=c++17"] "nvcc": ["-O3", "-std=c++17"] + arch_flags
} }
) )
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment