Commit 83c2f09d authored by Casper's avatar Casper
Browse files

Add compute capability

parent bfa5ba70
import os
import torch
from pathlib import Path
from torch.utils import cpp_extension
from setuptools import setup, find_packages
from distutils.sysconfig import get_python_lib
from torch.utils import cpp_extension, CUDA_HOME
os.environ["CC"] = "g++"
os.environ["CXX"] = "g++"
......@@ -52,6 +53,32 @@ conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/inc
if os.path.isdir(conda_cuda_include_dir):
include_dirs.append(conda_cuda_include_dir)
def check_dependencies():
if CUDA_HOME is None:
raise RuntimeError(
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
def get_compute_capabilities():
# Collect the compute capabilities of all available GPUs.
compute_capabilities = set()
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
if major < 8:
raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")
compute_capabilities.add(major * 10 + minor)
# figure out compute capability
compute_capabilities = {80, 86, 89, 90}
capability_flags = []
for cap in compute_capabilities:
capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
return capability_flags
check_dependencies()
arch_flags = get_compute_capabilities()
extensions = [
cpp_extension.CppExtension(
"awq_inference_engine",
......@@ -62,7 +89,7 @@ extensions = [
"awq_cuda/position_embedding/pos_encoding_kernels.cu"
], extra_compile_args={
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
"nvcc": ["-O3", "-std=c++17"]
"nvcc": ["-O3", "-std=c++17"] + arch_flags
}
)
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment