Unverified Commit 77efebbf authored by Chen Xin, committed by GitHub

convert model with hf repo_id (#774)

parent 5c9e1e28
@@ -23,13 +23,15 @@ class CLI(object):
                 dst_path: str = './workspace',
                 tp: int = 1,
                 quant_path: str = None,
-                group_size: int = 0):
+                group_size: int = 0,
+                **kwargs):
         """Convert LLMs to lmdeploy format.

         Args:
             model_name (str): The name of the to-be-deployed model, such as
                 llama-7b, llama-13b, vicuna-7b and etc.
-            model_path (str): The directory path of the model
+            model_path (str): The directory path of the model or huggingface
+                repo_id like 'internlm/internlm-chat-20b'
             model_format (str): the format of the model, should choose from
                 ['llama', 'hf', 'awq', None]. 'llama' stands for META's llama
                 format, 'hf' means huggingface llama format, and 'awq' means
@@ -43,6 +45,7 @@ class CLI(object):
             quant_path (str): Path of the quantized model, which can be None.
             group_size (int): A parameter used in AWQ to quantize fp16 weights
                 to 4 bits.
+            kwargs (dict): other params for convert
         """
         from lmdeploy.turbomind.deploy.converter import main as convert
@@ -53,7 +56,8 @@ class CLI(object):
                 dst_path=dst_path,
                 tp=tp,
                 quant_path=quant_path,
-                group_size=group_size)
+                group_size=group_size,
+                **kwargs)

     def list(self, engine: str = 'turbomind'):
         """List supported model names.
...
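With the change above, `model_path` accepts either a local directory or a Hugging Face repo_id, and any extra keyword arguments are forwarded through `**kwargs` to the converter. Below is a minimal sketch of driving the converter entrypoint this way; the `revision` and `cache_dir` keys are assumed to be among the options `create_hf_download_args` recognizes (they are standard `snapshot_download` parameters), not something this diff confirms.

from lmdeploy.turbomind.deploy.converter import main as convert

# model_path is given as a repo_id; the converter falls back to
# snapshot_download when the path does not exist on local disk.
convert(
    model_name='internlm-chat-20b',
    model_path='internlm/internlm-chat-20b',
    dst_path='./workspace',
    tp=1,
    revision='main',         # assumed download option, forwarded via **kwargs
    cache_dir='./hf_cache',  # assumed download option, forwarded via **kwargs
)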
@@ -6,8 +6,10 @@ import shutil
 from pathlib import Path

 import fire
+from huggingface_hub import snapshot_download

 from lmdeploy.model import MODELS
+from lmdeploy.turbomind.utils import create_hf_download_args

 from .source_model.base import INPUT_MODELS
 from .target_model.base import OUTPUT_MODELS, TurbomindModelConfig
@@ -143,7 +145,8 @@ def main(model_name: str,
          dst_path: str = 'workspace',
          tp: int = 1,
          quant_path: str = None,
-         group_size: int = 0):
+         group_size: int = 0,
+         **kwargs):
     """deploy llama family models via turbomind.

     Args:
@@ -162,6 +165,7 @@ def main(model_name: str,
         quant_path (str): Path of the quantized model, which can be None.
         group_size (int): a parameter used in AWQ to quantize fp16 weights
             to 4 bits
+        kwargs (dict): other params for convert
     """
     assert model_name in MODELS.module_dict.keys(), \
@@ -184,6 +188,13 @@ def main(model_name: str,
               f'which is not in supported list {supported_keys}')
         exit(-1)

+    if not os.path.exists(model_path):
+        print(f'can\'t find model from local_path {model_path}, '
+              'try to download from huggingface')
+        download_kwargs = create_hf_download_args(**kwargs)
+        model_path = snapshot_download(model_path, **download_kwargs)
+        print(f'load model from {model_path}')
+
     # get tokenizer path
     tokenizer_path = get_tokenizer_path(model_path, tokenizer_path)
...
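The block added to `main` implements the fallback: when `model_path` does not exist on disk it is treated as a repo_id, fetched with `huggingface_hub.snapshot_download`, and the returned local snapshot directory replaces `model_path` for the rest of the conversion. A self-contained sketch of that pattern follows; `resolve_model_path` and the explicit key filtering are illustrative stand-ins, since the real filtering lives in lmdeploy's internal `create_hf_download_args`, whose exact key set this diff does not show.

import os

from huggingface_hub import snapshot_download

# Assumed subset of snapshot_download options that the converter forwards;
# the real list is defined by lmdeploy's create_hf_download_args.
_DOWNLOAD_KEYS = {'revision', 'cache_dir', 'proxies', 'token', 'local_files_only'}


def resolve_model_path(model_path: str, **kwargs) -> str:
    """Return a local directory for model_path, downloading the repo from
    the Hugging Face Hub when it is not present on disk."""
    if os.path.exists(model_path):
        return model_path
    download_kwargs = {k: v for k, v in kwargs.items() if k in _DOWNLOAD_KEYS}
    # model_path is interpreted as a repo_id, e.g. 'internlm/internlm-chat-20b'
    return snapshot_download(model_path, **download_kwargs)


# Usage: resolves to the given path if it exists locally, otherwise to a
# freshly downloaded snapshot of the Hub repo.
local_path = resolve_model_path('internlm/internlm-chat-20b', revision='main')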