Commit 7fce4b80 authored by gaoqiong's avatar gaoqiong
Browse files

根据DCU特性修改部分代码

parent 665a401e
......@@ -60,8 +60,15 @@ cd dist && pip3 install autoawq*
| Baichuan | 7B/13B |
| QWen | 1.8B/7B/14/72B |
## 验证
- python -c "import awq; print(awq.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.2.5;
## Known Issue
-
## 参考资料
- [README](README.md)
- [https://github.com/casper-hansen/AutoAWQ](https://github.com/casper-hansen/AutoAWQ.git)
......
......@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, TextStreamer
quant_path = "casperhansen/llama-3-8b-instruct-awq"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False,use_exllama_v2=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
......
zstandard
transformers==4.42.3
\ No newline at end of file
transformers==4.42.3
datasets
\ No newline at end of file
......@@ -8,6 +8,10 @@ from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension
from typing import Optional, Union
add_git_version = False
if int(os.environ.get('ADD_GIT_VERSION', '0')) == 1:
add_git_version = True
def get_latest_kernels_version(repo):
"""
Get the latest version of the kernels from the github repo.
......@@ -50,16 +54,21 @@ def get_abi():
def get_version_add(sha: Optional[str] = None) -> str:
command = "git config --global --add safe.directory "+pwd
result = subprocess.run(command, shell=True, capture_output=False, text=True)
version=''
autoawq_root = os.path.dirname(os.path.abspath(__file__))
add_version_path = os.path.join(os.path.join(autoawq_root, "awq"), "__init__.py")
if sha != 'Unknown':
if sha is None:
sha = get_sha(autoawq_root)
version = 'git' + sha[:7]
if add_git_version:
if sha != 'Unknown':
if sha is None:
sha = get_sha(autoawq_root)
version = 'das.opt1.' + sha[:7]
else:
version = 'das.opt1'
# abi
version += "." + get_abi()
#version += "." + get_abi()
# dtk version
if os.getenv("ROCM_PATH"):
......@@ -78,7 +87,7 @@ def get_version_add(sha: Optional[str] = None) -> str:
with open(add_version_path, 'r',encoding='utf-8') as file:
lines = file.readlines()
lines[1] = "__dcu_version__ = '0.2.5+das1.1.{}'\n".format(version)
lines[1] = "__dcu_version__ = '0.2.5+{}'\n".format(version)
with open(add_version_path, encoding="utf-8",mode="w") as file:
file.writelines(lines)
file.close()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment