Unverified Commit 77efebbf authored by Chen Xin, committed by GitHub

convert model with hf repo_id (#774)

parent 5c9e1e28
@@ -23,13 +23,15 @@ class CLI(object):
                 dst_path: str = './workspace',
                 tp: int = 1,
                 quant_path: str = None,
-                group_size: int = 0):
+                group_size: int = 0,
+                **kwargs):
         """Convert LLMs to lmdeploy format.

         Args:
             model_name (str): The name of the to-be-deployed model, such as
                 llama-7b, llama-13b, vicuna-7b and etc.
-            model_path (str): The directory path of the model
+            model_path (str): The directory path of the model or huggingface
+                repo_id like 'internlm/internlm-chat-20b'
             model_format (str): the format of the model, should choose from
                 ['llama', 'hf', 'awq', None]. 'llama' stands for META's llama
                 format, 'hf' means huggingface llama format, and 'awq' means
@@ -43,6 +45,7 @@ class CLI(object):
             quant_path (str): Path of the quantized model, which can be None.
             group_size (int): A parameter used in AWQ to quantize fp16 weights
                 to 4 bits.
+            kwargs (dict): other params for convert
         """
         from lmdeploy.turbomind.deploy.converter import main as convert
@@ -53,7 +56,8 @@ class CLI(object):
             dst_path=dst_path,
             tp=tp,
             quant_path=quant_path,
-            group_size=group_size)
+            group_size=group_size,
+            **kwargs)

     def list(self, engine: str = 'turbomind'):
         """List supported model names.
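Note: with this change, any extra keyword arguments passed to `CLI.convert` flow through to the converter entry point. A minimal usage sketch (hypothetical call; whether an option such as `revision` actually takes effect depends on `create_hf_download_args`, which appears in the next file):

```python
# Sketch: convert by Hugging Face repo_id instead of a local directory.
# `revision` is an assumed example of a download option carried by **kwargs;
# it only matters when model_path is not a local directory and the weights
# must be fetched from the Hub.
from lmdeploy.turbomind.deploy.converter import main as convert

convert(model_name='internlm-chat-20b',
        model_path='internlm/internlm-chat-20b',  # repo_id, not a local path
        dst_path='./workspace',
        revision='main')
```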
@@ -6,8 +6,10 @@ import shutil
 from pathlib import Path

 import fire
+from huggingface_hub import snapshot_download

 from lmdeploy.model import MODELS
+from lmdeploy.turbomind.utils import create_hf_download_args
 from .source_model.base import INPUT_MODELS
 from .target_model.base import OUTPUT_MODELS, TurbomindModelConfig
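Note: `snapshot_download` downloads the whole repository snapshot (or reuses the local cache) and returns the directory it landed in, which is why `model_path` can simply be reassigned to its return value further down. A quick illustration, assuming network access and a public repo:

```python
from huggingface_hub import snapshot_download

# Returns the local cache directory holding the downloaded snapshot.
local_dir = snapshot_download('internlm/internlm-chat-20b')
print(local_dir)
```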
@@ -143,7 +145,8 @@ def main(model_name: str,
          dst_path: str = 'workspace',
          tp: int = 1,
          quant_path: str = None,
-         group_size: int = 0):
+         group_size: int = 0,
+         **kwargs):
     """deploy llama family models via turbomind.

     Args:
@@ -162,6 +165,7 @@ def main(model_name: str,
         quant_path (str): Path of the quantized model, which can be None.
         group_size (int): a parameter used in AWQ to quantize fp16 weights
             to 4 bits
+        kwargs (dict): other params for convert
     """
     assert model_name in MODELS.module_dict.keys(), \
@@ -184,6 +188,13 @@ def main(model_name: str,
               f'which is not in supported list {supported_keys}')
         exit(-1)

+    if not os.path.exists(model_path):
+        print(f'can\'t find model from local_path {model_path}, '
+              'try to download from huggingface')
+        download_kwargs = create_hf_download_args(**kwargs)
+        model_path = snapshot_download(model_path, **download_kwargs)
+        print(f'load model from {model_path}')
+
     # get tokenizer path
     tokenizer_path = get_tokenizer_path(model_path, tokenizer_path)
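Note: `create_hf_download_args` is imported above but its body is outside this diff. A plausible sketch, assuming it merely filters the caller's kwargs down to options that `huggingface_hub.snapshot_download` accepts:

```python
# Hypothetical sketch only; the real helper lives in lmdeploy.turbomind.utils
# and is not part of this commit's visible diff.
def create_hf_download_args(**kwargs) -> dict:
    """Keep only kwargs that snapshot_download understands."""
    allowed = ('revision', 'cache_dir', 'proxies', 'resume_download',
               'force_download', 'token', 'local_files_only')
    return {k: v for k, v in kwargs.items() if k in allowed}
```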