infer_transformers.py
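"""Minimal two-image inference with Llama 4 via Hugging Face transformers.

Usage (assumes access to the gated meta-llama checkpoint, or a local copy of it):
    python infer_transformers.py --model_id meta-llama/Llama-4-Scout-17B-16E-Instruct
"""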
import torch
import argparse

from transformers import AutoProcessor, Llama4ForConditionalGeneration

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="meta-llama/Llama-4-Scout-17B-16E-Instruct", help="Hugging Face model ID or local checkpoint path")

    args = parser.parse_args()

    return args


if __name__ == "__main__":
    # Parse command-line arguments
    args = get_args()
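    # The processor bundles the tokenizer and image preprocessor for the model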
    processor = AutoProcessor.from_pretrained(args.model_id)
    model = Llama4ForConditionalGeneration.from_pretrained(
        args.model_id,
        # attn_implementation="flex_attention",  # optional; requires torch >= 2.5
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
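
    # Local sample images; replace these paths with files that exist on your machine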
    url1 = "datasets/rabbit.jpg"
    url2 = "datasets/cat_style_layout.png"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": url1},
                {"type": "image", "url": url2},
                {"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
            ]
        },
    ]
    # Apply the chat template: tokenize the text and preprocess both images into model inputs
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    # Generate a response
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
    )

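    # Decode only the newly generated tokens by slicing off the prompt portion;
    # add skip_special_tokens=True to batch_decode to strip end-of-turn markers if desired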
    response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
    print(response)
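    # Also print the raw token IDs of the full sequence (prompt + completion) for debugging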
    print(outputs[0])