Commit 7e361d16 authored by Casper
Browse files

Add cuda_runtime in include_dirs

parent 5bc0916b
import os import os
import torch import torch
from setuptools import setup, find_packages from setuptools import setup, find_packages
from distutils.sysconfig import get_python_lib
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
def check_dependencies():
    """Verify that a CUDA toolkit is available for building the extension.

    ``CUDA_HOME`` is resolved by ``torch.utils.cpp_extension`` at import
    time; it is ``None`` when no CUDA toolkit can be located.

    Raises:
        RuntimeError: if ``CUDA_HOME`` is ``None`` (no CUDA toolkit found,
            or torch was installed without CUDA support).
    """
    if CUDA_HOME is None:
        # Original used an f-string with no placeholders; plain literal is
        # correct here. Message text unchanged.
        raise RuntimeError(
            "Cannot find CUDA_HOME. CUDA must be available to build the package.")
def get_compute_capabilities():
    """Build the nvcc ``-gencode`` flags for the supported architectures.

    Every visible GPU is first validated (compute capability >= 8.0 is
    required); the emitted flags then target a fixed set of architectures
    so the built wheel also runs on machines other than the build host.

    Returns:
        list[str]: a flat list of alternating ``"-gencode"`` /
        ``"arch=compute_XX,code=sm_XX"`` pairs.

    Raises:
        RuntimeError: if any attached GPU has compute capability < 8.0.
    """
    # Validate attached GPUs only. The original code also accumulated each
    # GPU's capability into a set that was immediately overwritten by the
    # hard-coded set below — that dead store is dropped here.
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        if major < 8:
            raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")

    # Hard-coded target architectures (SM 8.0 / 8.6 / 8.9 / 9.0).
    compute_capabilities = {80, 86, 89, 90}

    # sorted() gives a deterministic flag order (set iteration order is
    # an implementation detail); nvcc does not care about the order.
    capability_flags = []
    for cap in sorted(compute_capabilities):
        capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]

    return capability_flags
# Define dependencies # Define dependencies
dependencies = [ dependencies = [
...@@ -25,29 +33,22 @@ dependencies = [ ...@@ -25,29 +33,22 @@ dependencies = [
"transformers>=4.32.0", "transformers>=4.32.0",
"lm_eval", "texttable", "lm_eval", "texttable",
"toml", "attributedict", "toml", "attributedict",
"protobuf" "protobuf",
"torch>=2.0.0", "torchvision"
] ]
if not torch_is_prebuilt: # Get environment variables
dependencies.extend(["torch>=2.0.0", "torchvision"]) build_cuda_extension = os.environ.get('BUILD_CUDA_EXT', '1') == '1'
# Setup CUDA extension # Setup CUDA extension
ext_modules = [] ext_modules = []
if build_cuda_extension: if build_cuda_extension:
# figure out compute capability
compute_capabilities = {80, 86, 89, 90}
if torch_is_prebuilt:
compute_capabilities.update({87})
capability_flags = []
for cap in compute_capabilities:
capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
# num threads # num threads
n_threads = str(min(os.cpu_count(), 8)) n_threads = str(min(os.cpu_count(), 8))
# final args # final args
capability_flags = get_compute_capabilities()
cxx_args = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"] cxx_args = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"]
nvcc_args = ["-O3", "-std=c++17", "--threads", n_threads] + capability_flags nvcc_args = ["-O3", "-std=c++17", "--threads", n_threads] + capability_flags
...@@ -67,6 +68,13 @@ if build_cuda_extension: ...@@ -67,6 +68,13 @@ if build_cuda_extension:
) )
) )
# Find directories to be included in setup.
# Presumably the "nvidia/cuda_runtime/include" directory under site-packages
# is installed by a CUDA runtime pip/conda package — verify against the
# build environment. When present, it lets the extension find CUDA headers
# without a system-wide CUDA install.
# NOTE(review): get_python_lib comes from the deprecated distutils.sysconfig;
# consider a sysconfig-based replacement before distutils is removed.
include_dirs = []
conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
if os.path.isdir(conda_cuda_include_dir):
    include_dirs.append(conda_cuda_include_dir)
setup( setup(
name="autoawq", name="autoawq",
version="0.1.0", version="0.1.0",
...@@ -90,6 +98,7 @@ setup( ...@@ -90,6 +98,7 @@ setup(
"Programming Language :: C++", "Programming Language :: C++",
], ],
install_requires=dependencies, install_requires=dependencies,
include_dirs=include_dirs,
packages=find_packages(exclude=["examples*"]), packages=find_packages(exclude=["examples*"]),
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} cmdclass={"build_ext": BuildExtension}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.