Commit 9a7de7de authored by chenych

Modify inference.py and README.

parent 74c7c1cc
@@ -106,7 +106,10 @@ HIP_VISIBLE_DEVICES=0,1 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_l
 ### Single-node, single-GPU
 ```bash
-python inference.py --model_path /path/of/gemma2
+# Specify which GPU(s) to use
+export HIP_VISIBLE_DEVICES=0,1
+# Adjust the max_new_tokens argument as needed
+python inference.py --model_path /path/of/gemma2 --max_new_tokens xxx
 ```
 ## Result
@@ -114,7 +117,7 @@ python inference.py --model_path /path/of/gemma2
 - Model: gemma-2-9b
 <div align=center>
-<img src="./docs/results.png" witdh=1200 height=400/>
+<img src="./docs/results.png"/>
 </div>
 ### Accuracy
@@ -3,10 +3,8 @@ import argparse
 from transformers import AutoModelForCausalLM, AutoTokenizer
-# Specify which GPU to use
-os.environ['HIP_VISIBLE_DEVICES'] = '0'
-def infer_hf(model_path, input_text):
+def infer_hf(model_path, input_text, max_new_token=32):
     ''' Run gemma2 inference with transformers '''
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     model = AutoModelForCausalLM.from_pretrained(
@@ -16,7 +14,7 @@ def infer_hf(model_path, input_text):
     input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
-    outputs = model.generate(**input_ids, max_new_tokens=1024)
+    outputs = model.generate(**input_ids, max_new_tokens=max_new_token)
     print(tokenizer.decode(outputs[0]))
@@ -26,10 +24,11 @@ def parse_args():
                         default='Write me a poem about Machine Learning.',
                         help='')
     parser.add_argument('--model_path', default='/path/of/gemma2')
+    parser.add_argument('--max_new_tokens', default=32, type=int)
     return parser.parse_args()


 if __name__ == '__main__':
     args = parse_args()
-    infer_hf(args.model_path, args.input_text)
+    infer_hf(args.model_path, args.input_text, args.max_new_tokens)
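
Putting the hunks together, here is a minimal sketch of how inference.py reads after this commit. The `from_pretrained` keyword arguments are elided in the diff, so `device_map="auto"` below is an assumption standing in for the actual loading options; the argparse defaults are taken from the visible context lines.

```python
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer


def infer_hf(model_path, input_text, max_new_token=32):
    ''' Run gemma2 inference with transformers '''
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # The loading kwargs are elided in the diff; device_map="auto" is an
    # assumption standing in for whatever the script actually passes here.
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
    outputs = model.generate(**input_ids, max_new_tokens=max_new_token)
    print(tokenizer.decode(outputs[0]))


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_text',
                        default='Write me a poem about Machine Learning.',
                        help='')
    parser.add_argument('--model_path', default='/path/of/gemma2')
    parser.add_argument('--max_new_tokens', default=32, type=int)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    infer_hf(args.model_path, args.input_text, args.max_new_tokens)
```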