# infer_transformers.py — Llama-4 multimodal (two-image) inference with Hugging Face Transformers.
import argparse
import torch

from transformers import AutoProcessor, Llama4ForConditionalGeneration

def get_args():
    """Parse and return command-line arguments for the inference script."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    )
    return arg_parser.parse_args()


if __name__ == "__main__":
    # Parse command-line arguments (model id).
    args = get_args()

    # Load the processor (chat templating / image preprocessing) and the model.
    # NOTE(review): flex_attention + bfloat16 with device_map="auto" — assumes a
    # CUDA-capable multi-GPU or large-memory host; confirm for the target setup.
    processor = AutoProcessor.from_pretrained(args.model_id)
    model = Llama4ForConditionalGeneration.from_pretrained(
        args.model_id,
        attn_implementation="flex_attention",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # Local image paths used in a two-image comparison prompt.
    url1 = "datasets/rabbit.jpg"
    url2 = "datasets/cat_style_layout.png"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": url1},
                {"type": "image", "url": url2},
                {"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
            ]
        },
    ]

    # Apply the chat template, tokenize, and move the resulting tensors to the
    # model's (first) device.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
    )

    # Decode only the newly generated tokens: slice off the prompt portion
    # (everything up to the original input length) before detokenizing.
    response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
    print(response)
    print(outputs[0])