reference_hf.py 5.72 KB
Newer Older
1
2
3
4
5
"""
Usage:
python3 reference_hf.py --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4
python3 reference_hf.py --model-type vlm --model-path <image-text-to-text model>

Reference output (text model, default arguments):

========== Prompt 0 ==========
prefill logits (final) tensor([-8.3125, -7.1172,  3.3398,  ..., -4.9531, -4.1328, -3.4141],
       device='cuda:0')
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.

========== Prompt 1 ==========
prefill logits (final) tensor([-8.9062, -9.0156,  4.1484,  ..., -4.9922, -4.4961, -4.0742],
       device='cuda:0')
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of

========== Prompt 2 ==========
prefill logits (final) tensor([-9.6328, -9.0547,  4.0234,  ..., -5.3047, -4.7148, -4.4609],
       device='cuda:0')
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
26
27
import argparse

28
import requests
Lianmin Zheng's avatar
Lianmin Zheng committed
29
import torch
Yineng Zhang's avatar
Yineng Zhang committed
30
from PIL import Image
31
from transformers import (
Yineng Zhang's avatar
Yineng Zhang committed
32
33
34
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
35
)
36
37

from sglang.srt.hf_transformers_utils import get_tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
38
39


@torch.no_grad()
def vlm_text_with_image(args):
    """Print reference prefill logits and greedy output for an image+text prompt.

    Loads an ``AutoProcessor`` and an ``AutoModelForImageTextToText`` from
    ``args.model_path``. For each image URL, it renders a one-turn chat
    ("Describe this image."), runs greedy ``generate`` for
    ``args.max_new_tokens`` tokens, and runs one ``forward`` pass over the same
    inputs. It then prints the last-position logits and the decoded output.

    Args:
        args: Parsed CLI namespace; reads ``model_path``, ``dtype`` and
            ``max_new_tokens``.

    Raises:
        ValueError: If the processor has no ``apply_chat_template`` method.
        requests.HTTPError: If downloading an image returns an error status.
    """
    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)

    # Not all processors support chat templates (LLaVA and Qwen ones do).
    # Check before loading the model weights, so an unsupported processor
    # fails fast instead of after a potentially long model load.
    if not hasattr(processor, "apply_chat_template"):
        raise ValueError("The processor does not support chat templates.")

    model = AutoModelForImageTextToText.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )

    torch.cuda.set_device(0)

    # List of image URLs to process
    image_urls = [
        "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    ]

    # One user turn: an image placeholder followed by the text instruction.
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]

    # The conversation does not depend on the image, so render it only once.
    text_prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True
    )

    max_new_tokens = args.max_new_tokens

    for i, url in enumerate(image_urls):
        # Download the image. The timeout keeps a stalled connection from
        # hanging the script. raise_for_status surfaces HTTP errors instead of
        # letting PIL fail later on an error page body.
        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()
        image = Image.open(response.raw)

        # Prepare inputs for the model
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(
            "cuda:0"
        )

        # Greedy generation from the model
        output_ids = model.generate(
            **inputs, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = processor.decode(output_ids[0])

        # Logits at the final prompt position from a single forward pass
        outputs = model.forward(**inputs)
        logits = outputs.logits[0, -1, :]

        print(f"\n========== Image {i} ==========")
        print("prefill logits (final)", logits)
        # TODO(gaocegege): The output contains numerous <|image_pad|> tokens,
        # making it cluttered and difficult to read.
        # These tokens should be removed or cleaned up for better readability.
        print(output_str)


@torch.no_grad()
def normal_text(args):
    """Print reference prefill logits and greedy output for text-only prompts.

    For each prompt, it runs greedy ``generate`` for ``args.max_new_tokens``
    tokens and one ``forward`` pass over the prompt. It then prints the
    last-position logits and the decoded output (prompt plus continuation).

    Args:
        args: Parsed CLI namespace; reads ``model_path``, ``dtype`` and
            ``max_new_tokens``.
    """
    t = get_tokenizer(args.model_path, trust_remote_code=True)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )

    # Each entry is either a prompt string or a list of token ids.
    # NOTE: the "Kindom" typo is kept verbatim, so the printed output stays
    # comparable with the reference output in the module docstring.
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    max_new_tokens = args.max_new_tokens

    torch.cuda.set_device(0)

    for i, p in enumerate(prompts):
        if isinstance(p, str):
            input_ids = t.encode(p, return_tensors="pt").to("cuda:0")
        else:
            input_ids = torch.tensor([p], device="cuda:0")

        # A single unpadded sequence, so every position is attended to.
        # Passing the mask explicitly matches the mask-free forward() call
        # below, and generate() does not have to infer it from the pad token.
        attention_mask = torch.ones_like(input_ids)
        output_ids = m.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )
        output_str = t.decode(output_ids[0])

        # Logits at the last prompt position (the prefill output).
        prefill_logits = m.forward(input_ids).logits[0][-1]

        print(f"\n========== Prompt {i} ==========")
        print("prefill logits (final)", prefill_logits)
        print(output_str)


@torch.no_grad()
def synthetic_tokens(args):
    """Print logits for a synthetic token-id prompt plus a few greedy steps.

    The prompt is the token ids ``5 .. 5 + 255``. Each step re-runs a full
    ``forward`` pass over the whole, growing sequence and appends the argmax
    token. Step 0 prints the prefill logits; later steps print the logits for
    each decoded position.

    Args:
        args: Parsed CLI namespace; reads ``model_path`` and, if present,
            ``dtype`` (falls back to ``torch.float16`` when absent).
    """
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        # Honor --dtype like the other loaders; keep the old float16 default.
        torch_dtype=getattr(args, "dtype", torch.float16),
        low_cpu_mem_usage=True,
    )
    m.cuda()
    print(m)

    input_len = 256
    output_len = 8
    prompts = [list(range(5, 5 + input_len))]

    for p in prompts:
        # Copy, so appending generated tokens does not mutate the prompt list.
        input_ids = list(p)
        for i in range(output_len + 1):
            step_logits = m.forward(
                torch.tensor([input_ids], device="cuda")
            ).logits[0][-1]

            if i == 0:
                print("prefill logits", step_logits)
            else:
                print("decode", i - 1, step_logits)

            input_ids.append(torch.argmax(step_logits).item())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
        # default="meta-llama/Llama-2-7b-chat-hf",
        help="Hugging Face model id or local path.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=16,
        help="Number of tokens to generate greedily per prompt.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        help="Value forwarded as torch_dtype to from_pretrained.",
    )
    parser.add_argument(
        "--model-type",
        type=str,
        default="text",
        help='"vlm" runs the image+text check, "synthetic" runs the '
        "synthetic-token check, anything else runs the text prompts.",
    )

    args = parser.parse_args()

    if args.model_type == "vlm":
        vlm_text_with_image(args)
    elif args.model_type == "synthetic":
        # Previously unreachable from the CLI.
        synthetic_tokens(args)
    else:
        normal_text(args)