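"""infer_transformers.py

Multimodal inference with Llama 4 via Hugging Face Transformers: load a
Llama4ForConditionalGeneration checkpoint, feed it two images plus a text
question through the chat template, and print the generated answer.
"""
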
import argparse
import torch

from transformers import AutoProcessor, Llama4ForConditionalGeneration

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="meta-llama/Llama-4-Scout-17B-16E-Instruct")

    args = parser.parse_args()

    return args


if __name__ == "__main__":
    # Parse command-line arguments
    args = get_args()
    processor = AutoProcessor.from_pretrained(args.model_id)
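    # Load the weights in bfloat16 with the FlexAttention backend, sharding
    # layers across all visible devices via device_map="auto" (the Scout
    # 17B-16E checkpoint typically does not fit on a single GPU)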
    model = Llama4ForConditionalGeneration.from_pretrained(
        args.model_id,
        attn_implementation="flex_attention",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

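    # Two sample images (a rabbit and a cat) from the Hugging Face
    # documentation-images dataset; the processor downloads and preprocesses
    # them from these URLs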
    url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
    url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": url1},
                {"type": "image", "url": url2},
                {"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
            ]
        },
    ]
    # Render the chat template and tokenize the multimodal conversation,
    # returning a dict of tensors moved onto the model's device
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

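    # Generate up to 256 new tokens using the model's default decoding settings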
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
    )

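    # Slice off the prompt tokens so only the newly generated text is decoded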
    response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
    print(response)
    print(outputs[0])  # raw token IDs for the full sequence, prompt included
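
# Example invocation (assumes access to the gated meta-llama checkpoint and
# enough GPU memory to shard it):
#   python infer_transformers.py --model_id meta-llama/Llama-4-Scout-17B-16E-Instruct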