Commit 7e361d16 authored by Casper
Browse files

Add cuda_runtime in include_dirs

parent 5bc0916b
import os
import torch
from setuptools import setup, find_packages
from distutils.sysconfig import get_python_lib
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
if CUDA_HOME is None:
raise RuntimeError(
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
def check_dependencies():
    """Verify that the CUDA toolkit needed to build the extension was located.

    Raises:
        RuntimeError: If ``CUDA_HOME`` (as resolved by
            ``torch.utils.cpp_extension``) is ``None``.
    """
    if CUDA_HOME is None:
        # Plain string literal: the message has no placeholders, so no f-prefix.
        raise RuntimeError(
            "Cannot find CUDA_HOME. CUDA must be available to build the package."
        )
# Collect the compute capabilities of all available GPUs.
compute_capabilities = set()
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
if major < 8:
raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")
compute_capabilities.add(major * 10 + minor)
# Get environment variables
build_cuda_extension = os.environ.get('BUILD_CUDA_EXT', '1') == '1'
torch_is_prebuilt = os.environ.get('TORCH_IS_PREBUILT', '0') == '1'
def get_compute_capabilities():
    """Return the nvcc ``-gencode`` flags for the supported GPU architectures.

    Every GPU visible to torch is validated first; the flags themselves are
    always generated for the same fixed list of architectures, independent of
    which GPUs are installed on the build machine.

    Returns:
        list[str]: Alternating ``"-gencode"`` / ``"arch=compute_XX,code=sm_XX"``
        entries, one pair per supported architecture, in ascending order.

    Raises:
        RuntimeError: If any visible GPU reports a compute capability whose
            major version is below 8.
    """
    # Reject unsupported hardware up front. Only the major version matters for
    # this check, so the detected capabilities are not accumulated.
    for device_index in range(torch.cuda.device_count()):
        major, _minor = torch.cuda.get_device_capability(device_index)
        if major < 8:
            raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")

    # Fixed build targets. A tuple (rather than a set) keeps the order of the
    # generated flags deterministic from one build to the next.
    supported_capabilities = (80, 86, 89, 90)

    capability_flags = []
    for cap in supported_capabilities:
        capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]

    return capability_flags
# Define dependencies
dependencies = [
......@@ -25,29 +33,22 @@ dependencies = [
"transformers>=4.32.0",
"lm_eval", "texttable",
"toml", "attributedict",
"protobuf"
"protobuf",
"torch>=2.0.0", "torchvision"
]
if not torch_is_prebuilt:
dependencies.extend(["torch>=2.0.0", "torchvision"])
# Get environment variables
build_cuda_extension = os.environ.get('BUILD_CUDA_EXT', '1') == '1'
# Setup CUDA extension
ext_modules = []
if build_cuda_extension:
# figure out compute capability
compute_capabilities = {80, 86, 89, 90}
if torch_is_prebuilt:
compute_capabilities.update({87})
capability_flags = []
for cap in compute_capabilities:
capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
# num threads
n_threads = str(min(os.cpu_count(), 8))
# final args
capability_flags = get_compute_capabilities()
cxx_args = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"]
nvcc_args = ["-O3", "-std=c++17", "--threads", n_threads] + capability_flags
......@@ -67,6 +68,13 @@ if build_cuda_extension:
)
)
# Find directories to be included in setup.
# Probe the directory reported by get_python_lib() for a bundled CUDA runtime
# include folder, and hand it to setup() only when it actually exists.
conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
include_dirs = [conda_cuda_include_dir] if os.path.isdir(conda_cuda_include_dir) else []
setup(
name="autoawq",
version="0.1.0",
......@@ -90,6 +98,7 @@ setup(
"Programming Language :: C++",
],
install_requires=dependencies,
include_dirs=include_dirs,
packages=find_packages(exclude=["examples*"]),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment