simple_inference.py 645 Bytes
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
"""Minimal chat-style inference demo for HuatuoGPT-o1 (Qwen 7B variant).

Loads the model and tokenizer from a local weights directory, formats a
single user turn with the model's chat template, generates a response,
and prints only the newly generated text.
"""
from transformers import AutoModelForCausalLM, AutoTokenizer

# Local path to the downloaded HuatuoGPT-o1 weights; adjust for your setup.
MODEL_PATH = "/home/modelzoo/HuatuoGPT-o1/weights/HuatuoGPT-o1-7B-Qwen"


def main() -> None:
    """Run one question through the model and print the generated answer."""
    # device_map="auto" shards/places the model automatically (CPU/GPU);
    # torch_dtype="auto" picks the dtype stored in the checkpoint.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH, torch_dtype="auto", device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    input_text = "孩子咳嗽老不好怎么办?"
    messages = [{"role": "user", "content": input_text}]

    # Render the chat template to a string, then tokenize it, so the
    # prompt matches the format the model was trained on.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(**inputs, max_new_tokens=2048)

    # Slice off the prompt tokens so only the model's new answer is
    # printed (decoding outputs[0] directly would echo the prompt too).
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    print(tokenizer.decode(generated, skip_special_tokens=True))


if __name__ == "__main__":
    main()