"""Quick inference smoke test for a local MiniCPM checkpoint.

Loads the bf16 model from a local checkpoint directory onto CUDA, runs a
single chat-style prompt, and reports wall-clock inference time.
"""
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fixed seed so sampled generation (temperature/top_p) is reproducible run-to-run.
torch.manual_seed(0)

# path = "output/AdvertiseGenLoRA_lora_finetune/xxx/checkpoint-3000"  # xxx: system-time subdirectory
path = 'checkpoint/miniCPM-bf16'

tokenizer = AutoTokenizer.from_pretrained(path)
# trust_remote_code=True is required because MiniCPM ships custom modeling code
# alongside the checkpoint; only enable it for checkpoints from a trusted source.
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map='cuda', trust_remote_code=True)

# Time a single chat turn (prompt asks about Shandong's highest mountain vs. Huangshan).
start_time = time.time()
responds, history = model.chat(tokenizer, "山东省最高的山是哪座山, 它比黄山高还是矮?差距多少?", temperature=0.5, top_p=0.8, repetition_penalty=1.02)
print("infer time:", time.time() - start_time, "s")
print(responds)