Unverified Commit be0f9e01 authored by Casper's avatar Casper Committed by GitHub
Browse files

Fix workflow (#345)

parent 171c7aff
#!/bin/bash
# Set variables
AWQ_VERSION="0.1.6"
AWQ_VERSION="0.2.0"
RELEASE_URL="https://api.github.com/repos/casper-hansen/AutoAWQ/releases/tags/v${AWQ_VERSION}"
# Create a directory to download the wheels
......
......@@ -2,9 +2,9 @@ import os
import torch
import platform
import requests
import importlib_metadata
from pathlib import Path
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension
def get_latest_kernels_version(repo):
......@@ -88,15 +88,20 @@ requirements = [
"torch>=2.0.1",
"transformers>=4.35.0",
"tokenizers>=0.12.1",
"typing_extensions>=4.8.0"
"accelerate",
"datasets",
"zstandard",
]
try:
importlib_metadata.version("autoawq-kernels")
if ROCM_VERSION:
import exlv2_ext
else:
import awq_ext
KERNELS_INSTALLED = True
except importlib_metadata.PackageNotFoundError:
except ImportError:
KERNELS_INSTALLED = False
# kernels can be downloaded from pypi for cuda+121 only
......@@ -133,5 +138,14 @@ setup(
"eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"],
"dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"]
},
# NOTE: We create an empty CUDAExtension because torch helps us with
# creating the right boilerplate to enable correct targeting of
# the autoawq-kernels package
ext_modules=[
CUDAExtension(
name="__build_artifact_for_awq_kernel_targeting",
sources=[],
)
],
**common_setup_kwargs,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment