Unverified Commit be0f9e01 authored by Casper's avatar Casper Committed by GitHub
Browse files

Fix workflow (#345)

parent 171c7aff
#!/bin/bash #!/bin/bash
# Set variables # Set variables
AWQ_VERSION="0.1.6" AWQ_VERSION="0.2.0"
RELEASE_URL="https://api.github.com/repos/casper-hansen/AutoAWQ/releases/tags/v${AWQ_VERSION}" RELEASE_URL="https://api.github.com/repos/casper-hansen/AutoAWQ/releases/tags/v${AWQ_VERSION}"
# Create a directory to download the wheels # Create a directory to download the wheels
......
...@@ -2,9 +2,9 @@ import os ...@@ -2,9 +2,9 @@ import os
import torch import torch
import platform import platform
import requests import requests
import importlib_metadata
from pathlib import Path from pathlib import Path
from setuptools import setup, find_packages from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension
def get_latest_kernels_version(repo): def get_latest_kernels_version(repo):
...@@ -88,15 +88,20 @@ requirements = [ ...@@ -88,15 +88,20 @@ requirements = [
"torch>=2.0.1", "torch>=2.0.1",
"transformers>=4.35.0", "transformers>=4.35.0",
"tokenizers>=0.12.1", "tokenizers>=0.12.1",
"typing_extensions>=4.8.0"
"accelerate", "accelerate",
"datasets", "datasets",
"zstandard", "zstandard",
] ]
try: try:
importlib_metadata.version("autoawq-kernels") if ROCM_VERSION:
import exlv2_ext
else:
import awq_ext
KERNELS_INSTALLED = True KERNELS_INSTALLED = True
except importlib_metadata.PackageNotFoundError: except ImportError:
KERNELS_INSTALLED = False KERNELS_INSTALLED = False
# kernels can be downloaded from pypi for cuda+121 only # kernels can be downloaded from pypi for cuda+121 only
...@@ -133,5 +138,14 @@ setup( ...@@ -133,5 +138,14 @@ setup(
"eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"], "eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"],
"dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"] "dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"]
}, },
# NOTE: We create an empty CUDAExtension because torch helps us with
# creating the right boilerplate to enable correct targeting of
# the autoawq-kernels package
ext_modules=[
CUDAExtension(
name="__build_artifact_for_awq_kernel_targeting",
sources=[],
)
],
**common_setup_kwargs, **common_setup_kwargs,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment