Commit 7fce4b80 authored by gaoqiong's avatar gaoqiong
Browse files

根据DCU特性修改部分代码

parent 665a401e
...@@ -60,8 +60,15 @@ cd dist && pip3 install autoawq* ...@@ -60,8 +60,15 @@ cd dist && pip3 install autoawq*
| Baichuan | 7B/13B | | Baichuan | 7B/13B |
| QWen | 1.8B/7B/14/72B | | QWen | 1.8B/7B/14/72B |
## 验证
- python -c "import awq; print(awq.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.2.5;
## Known Issue
-
## 参考资料
- [README](README.md)
- [https://github.com/casper-hansen/AutoAWQ](https://github.com/casper-hansen/AutoAWQ.git)
......
...@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, TextStreamer ...@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, TextStreamer
quant_path = "casperhansen/llama-3-8b-instruct-awq" quant_path = "casperhansen/llama-3-8b-instruct-awq"
# Load model # Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True) model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False,use_exllama_v2=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
......
...@@ -8,6 +8,10 @@ from setuptools import setup, find_packages ...@@ -8,6 +8,10 @@ from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
from typing import Optional, Union from typing import Optional, Union
add_git_version = False
if int(os.environ.get('ADD_GIT_VERSION', '0')) == 1:
add_git_version = True
def get_latest_kernels_version(repo): def get_latest_kernels_version(repo):
""" """
Get the latest version of the kernels from the github repo. Get the latest version of the kernels from the github repo.
...@@ -50,16 +54,21 @@ def get_abi(): ...@@ -50,16 +54,21 @@ def get_abi():
def get_version_add(sha: Optional[str] = None) -> str: def get_version_add(sha: Optional[str] = None) -> str:
command = "git config --global --add safe.directory "+pwd
result = subprocess.run(command, shell=True, capture_output=False, text=True)
version='' version=''
autoawq_root = os.path.dirname(os.path.abspath(__file__)) autoawq_root = os.path.dirname(os.path.abspath(__file__))
add_version_path = os.path.join(os.path.join(autoawq_root, "awq"), "__init__.py") add_version_path = os.path.join(os.path.join(autoawq_root, "awq"), "__init__.py")
if add_git_version:
if sha != 'Unknown': if sha != 'Unknown':
if sha is None: if sha is None:
sha = get_sha(autoawq_root) sha = get_sha(autoawq_root)
version = 'git' + sha[:7] version = 'das.opt1.' + sha[:7]
else:
version = 'das.opt1'
# abi # abi
version += "." + get_abi() #version += "." + get_abi()
# dtk version # dtk version
if os.getenv("ROCM_PATH"): if os.getenv("ROCM_PATH"):
...@@ -78,7 +87,7 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -78,7 +87,7 @@ def get_version_add(sha: Optional[str] = None) -> str:
with open(add_version_path, 'r',encoding='utf-8') as file: with open(add_version_path, 'r',encoding='utf-8') as file:
lines = file.readlines() lines = file.readlines()
lines[1] = "__dcu_version__ = '0.2.5+das1.1.{}'\n".format(version) lines[1] = "__dcu_version__ = '0.2.5+{}'\n".format(version)
with open(add_version_path, encoding="utf-8",mode="w") as file: with open(add_version_path, encoding="utf-8",mode="w") as file:
file.writelines(lines) file.writelines(lines)
file.close() file.close()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment