Unverified commit 19d7e630, authored by Fengzhe Zhou and committed by GitHub
Browse files

[Sync] Update accelerator (#1122)



(cherry picked from commit 4beb6d9ab655d8a626971841b7acfd9fae9d438f)
Co-authored-by: liuhongwei <liuhongwei@pjlab.org.cn>
parent a71122ee
......@@ -53,8 +53,8 @@ def parse_args():
parser.add_argument(
'--accelerator',
help='Infer accelerator, support vllm and lmdeploy now.',
choices=['vllm', 'lmdeploy', 'hg'],
default='hg',
choices=['vllm', 'lmdeploy', 'hf'],
default='hf',
type=str)
parser.add_argument('-m',
'--mode',
......
......@@ -220,7 +220,7 @@ def change_accelerator(models, accelerator):
if accelerator == 'lmdeploy':
get_logger().info(
f'Transforming {model["abbr"]} to {accelerator}')
model = dict(
acc_model = dict(
type= # noqa E251
f'{TurboMindModel.__module__}.{TurboMindModel.__name__}',
abbr=model['abbr'].replace('hf', 'lmdeploy')
......@@ -242,12 +242,12 @@ def change_accelerator(models, accelerator):
)
for item in ['meta_template']:
if model.get(item) is not None:
model.update(item, model[item])
acc_model[item] = model[item]
elif accelerator == 'vllm':
get_logger().info(
f'Transforming {model["abbr"]} to {accelerator}')
model = dict(
acc_model = dict(
type=f'{VLLM.__module__}.{VLLM.__name__}',
abbr=model['abbr'].replace('hf', 'vllm')
if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
......@@ -262,12 +262,10 @@ def change_accelerator(models, accelerator):
)
for item in ['meta_template', 'end_str']:
if model.get(item) is not None:
model.update(item, model[item])
generation_kwargs.update(
dict(temperature=gen_args['temperature']))
acc_model[item] = model[item]
else:
raise ValueError(f'Unsupported accelerator {accelerator}')
model_accels.append(model)
model_accels.append(acc_model)
return model_accels
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment