infer_transformers.py
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="THUDM/GLM-Z1-9B-0414")
    parser.add_argument("--message", type=str, default="Let a, b be positive real numbers such that ab = a + b + 3. Determine the range of possible values for a + b.")

    args = parser.parse_args()

    return args


if __name__ == "__main__":
    # Parse command-line arguments
    args = get_args()

    # Load the tokenizer and model; device_map="auto" places the weights on
    # available devices automatically (this requires the accelerate package)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto")

    message = [{"role": "user", "content": args.message}]

    inputs = tokenizer.apply_chat_template(
        message,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
    ).to(model.device)

    # Greedy decoding (do_sample=False), generating at most 4096 new tokens
    generate_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 4096,
        "do_sample": False,
    }
    out = model.generate(**generate_kwargs)

    # Slice off the prompt tokens so that only the completion is printed
    print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
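
A minimal invocation, assuming transformers and accelerate are installed and there is enough GPU memory for the 9B checkpoint; both flags are optional and fall back to the defaults defined in get_args():

python infer_transformers.py --model_path THUDM/GLM-Z1-9B-0414 --message "Let a, b be positive real numbers such that ab = a + b + 3. Determine the range of possible values for a + b."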