infer_transformers.py 986 Bytes
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from  transformers import AutoTokenizer, AutoModelForCausalLM


if __name__ == '__main__':
    # Minimal chat-completion demo for ERNIE-4.5 via Hugging Face transformers.
    model_name = "baidu/ERNIE-4.5-0.3B-PT"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    model.eval()  # inference only; disables dropout etc.

    prompt = "Give me a short introduction to large language model."
    messages = [
        {"role": "user", "content": prompt}
    ]
    # Render the chat template to a plain string (tokenize=False) and append
    # the assistant-turn prefix so generation starts at the reply.
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # NOTE(review): add_special_tokens=True may duplicate BOS if the chat
    # template already emits it — confirm against the model card.
    # Move tensors to the model's device so this also works on GPU loads.
    model_inputs = tokenizer([text], add_special_tokens=True, return_tensors="pt").to(model.device)

    # conduct text completion; pass attention_mask explicitly so generate()
    # does not have to infer it (silences the warning and masks correctly
    # when pad_token_id == eos_token_id)
    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=1024,
    )
    # Keep only the newly generated tokens: drop the prompt prefix.
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

    # decode the generated ids
    generate_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
    print("generate_text:", generate_text)