import importlib.metadata
import importlib.util
import os
import platform
from pathlib import Path

import requests
import torch
from setuptools import setup, find_packages


def get_latest_kernels_version(repo):
    """Return the latest released version of the AutoAWQ kernels.

    Queries the GitHub releases API for *repo* (e.g.
    ``"casper-hansen/AutoAWQ_kernels"``) and strips the leading ``v``
    from the release tag.

    Raises:
        requests.HTTPError: if the GitHub API request fails.
        KeyError: if the API response carries no ``tag_name`` field.
    """
    response = requests.get(
        f"https://api.github.com/repos/{repo}/releases/latest",
        timeout=30,  # don't let a stalled connection hang the build forever
    )
    # Fail loudly on API errors (rate limit, 404, ...) instead of a
    # cryptic KeyError on the JSON body below.
    response.raise_for_status()
    data = response.json()
    tag_name = data["tag_name"]
    version = tag_name.replace("v", "")
    return version


def get_kernels_whl_url(
    gpu_system_version,
    release_version,
    python_version,
    platform,
    architecture,
):
    """Build the GitHub download URL for a prebuilt AutoAWQ kernels wheel."""
    base = "https://github.com/casper-hansen/AutoAWQ_kernels/releases/download"
    wheel_name = (
        f"autoawq_kernels-{release_version}+{gpu_system_version}"
        f"-cp{python_version}-cp{python_version}-{platform}_{architecture}.whl"
    )
    return f"{base}/v{release_version}/{wheel_name}"


# Base package version; a local build tag (+cuXXX / +rocmXXX) is appended
# below unless this is a PyPI (sdist) build.
AUTOAWQ_VERSION = "0.1.8"
PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"

# Resolve the CUDA toolkit version: env override first, then the version
# torch itself was built against. Normalized to digits, e.g. "11.8" -> "118".
CUDA_VERSION = os.getenv("CUDA_VERSION", None) or torch.version.cuda
if CUDA_VERSION:
    CUDA_VERSION = CUDA_VERSION.replace(".", "")[:3]

# Same resolution for ROCm (AMD) builds, via torch.version.hip.
ROCM_VERSION = os.getenv("ROCM_VERSION", None) or torch.version.hip
if ROCM_VERSION:
    # Pin the two ROCm minor series with published kernel wheels to their
    # exact patch release before normalizing, e.g. "5.6..." -> "561".
    for prefix, pinned in (("5.6", "5.6.1"), ("5.7", "5.7.1")):
        if ROCM_VERSION.startswith(prefix):
            ROCM_VERSION = pinned
            break

    ROCM_VERSION = ROCM_VERSION.replace(".", "")[:3]

if not PYPI_BUILD:
    # Source/GitHub builds carry a local version identifier naming the GPU
    # stack they were compiled against.
    if CUDA_VERSION:
        AUTOAWQ_VERSION += f"+cu{CUDA_VERSION}"
    elif ROCM_VERSION:
        AUTOAWQ_VERSION += f"+rocm{ROCM_VERSION}"
    else:
        raise RuntimeError(
            "Your system must have either Nvidia or AMD GPU to build this package."
        )
# Metadata shared by every build flavor; merged into the setup() call below.
common_setup_kwargs = dict(
    version=AUTOAWQ_VERSION,
    name="autoawq",
    author="Casper Hansen",
    license="MIT",
    python_requires=">=3.8.0",
    description="AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.",
    # The README next to this setup.py doubles as the PyPI long description.
    long_description=(Path(__file__).parent / "README.md").read_text(
        encoding="UTF-8"
    ),
    long_description_content_type="text/markdown",
    url="https://github.com/casper-hansen/AutoAWQ",
    keywords=["awq", "autoawq", "quantization", "transformers"],
    platforms=["linux", "windows"],
    classifiers=[
        "Environment :: GPU :: NVIDIA CUDA :: 11.8",
        "Environment :: GPU :: NVIDIA CUDA :: 12",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
    ],
)

# Runtime dependencies; the GPU kernel wheel is appended conditionally below.
requirements = [
    "torch>=2.0.1",
    "transformers>=4.35.0",
    "tokenizers>=0.12.1",
    "accelerate",
    "datasets",
]

# Detect whether the prebuilt AutoAWQ kernels package is already installed;
# if so, nothing extra needs to be added to the requirements.
try:
    importlib.metadata.version("autoawq-kernels")
    KERNELS_INSTALLED = True
except importlib.metadata.PackageNotFoundError:
    KERNELS_INSTALLED = False

# kernels can be downloaded from pypi for cuda+121 only
# for everything else, we need to download the wheels from github
if not KERNELS_INSTALLED and (CUDA_VERSION or ROCM_VERSION):
    # NOTE: CUDA_VERSION is None on ROCm-only systems, so it must be
    # truth-checked before calling .startswith() (previously crashed with
    # AttributeError before the ROCm branch could be reached).
    if CUDA_VERSION and CUDA_VERSION.startswith("12"):
        requirements.append("autoawq-kernels")
    elif (CUDA_VERSION and CUDA_VERSION.startswith("11")) or ROCM_VERSION in [
        "561",
        "571",
    ]:
        gpu_system_version = (
            f"cu{CUDA_VERSION}" if CUDA_VERSION else f"rocm{ROCM_VERSION}"
        )
        kernels_version = get_latest_kernels_version("casper-hansen/AutoAWQ_kernels")
        python_version = "".join(platform.python_version_tuple()[:2])
        platform_name = platform.system().lower()
        architecture = platform.machine().lower()
        # Direct-download URL for the prebuilt wheel matching this system.
        kernels_wheel_url = get_kernels_whl_url(
            gpu_system_version,
            kernels_version,
            python_version,
            platform_name,
            architecture,
        )
        requirements.append(f"autoawq-kernels@{kernels_wheel_url}")
    else:
        raise RuntimeError(
            "Your system have a GPU with an unsupported CUDA or ROCm version. "
            "Please install the kernels manually from https://github.com/casper-hansen/AutoAWQ_kernels"
        )
# Optional dependency groups: `pip install autoawq[eval]` pulls in the
# evaluation harness stack.
extras = {
    "eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"],
}

# Assemble the final distribution from the discovered packages, the hard
# requirements, the optional extras, and the shared metadata defined above.
setup(
    packages=find_packages(),
    install_requires=requirements,
    extras_require=extras,
    **common_setup_kwargs,
)