Commit 9a7de7de authored by chenych's avatar chenych
Browse files

Modify inference.py and README.

parent 74c7c1cc
...@@ -106,7 +106,10 @@ HIP_VISIBLE_DEVICES=0,1 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_l ...@@ -106,7 +106,10 @@ HIP_VISIBLE_DEVICES=0,1 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_l
### 单机单卡 ### 单机单卡
```bash ```bash
python inference.py --model_path /path/of/gemma2 # 指定卡号
export HIP_VISIBLE_DEVICES=0,1
# 根据实际情况修改max_new_tokens参数
python inference.py --model_path /path/of/gemma2 --max_new_tokens xxx
``` ```
## result ## result
...@@ -114,7 +117,7 @@ python inference.py --model_path /path/of/gemma2 ...@@ -114,7 +117,7 @@ python inference.py --model_path /path/of/gemma2
- 模型:gemma-2-9b - 模型:gemma-2-9b
<div align=center> <div align=center>
<img src="./docs/results.png" witdh=1200 height=400/> <img src="./docs/results.png"/>
</div> </div>
### 精度 ### 精度
......
...@@ -3,10 +3,8 @@ import argparse ...@@ -3,10 +3,8 @@ import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
# 卡号指定
os.environ['HIP_VISIBLE_DEVICES'] = '0'
def infer_hf(model_path, input_text): def infer_hf(model_path, input_text, max_new_token=32):
''' transformers 推理 gemma2''' ''' transformers 推理 gemma2'''
tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
...@@ -16,7 +14,7 @@ def infer_hf(model_path, input_text): ...@@ -16,7 +14,7 @@ def infer_hf(model_path, input_text):
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=1024) outputs = model.generate(**input_ids, max_new_tokens=max_new_token)
print(tokenizer.decode(outputs[0])) print(tokenizer.decode(outputs[0]))
...@@ -26,10 +24,11 @@ def parse_args(): ...@@ -26,10 +24,11 @@ def parse_args():
default='Write me a poem about Machine Learning.', default='Write me a poem about Machine Learning.',
help='') help='')
parser.add_argument('--model_path', default='/path/of/gemma2') parser.add_argument('--model_path', default='/path/of/gemma2')
parser.add_argument('--max_new_tokens', default=32, type=int)
return parser.parse_args() return parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
infer_hf(args.model_path, args.input_text) infer_hf(args.model_path, args.input_text, args.max_new_tokens)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment