"""
Usage:
python3 reference_hf.py --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4

Reference output:
========== Prompt 0 ==========
prefill logits (final) tensor([-8.3125, -7.1172,  3.3398,  ..., -4.9531, -4.1328, -3.4141],
       device='cuda:0')
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.

========== Prompt 1 ==========
prefill logits (final) tensor([-8.9062, -9.0156,  4.1484,  ..., -4.9922, -4.4961, -4.0742],
       device='cuda:0')
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of

========== Prompt 2 ==========
prefill logits (final) tensor([-9.6328, -9.0547,  4.0234,  ..., -5.3047, -4.7148, -4.4609],
       device='cuda:0')
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the
"""

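# Print the final-position prefill logits and the greedy completion from a
# plain HuggingFace Transformers run, so the numbers above can serve as a
# ground-truth reference for other implementations.
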
import argparse

import torch
from transformers import AutoModelForCausalLM

from sglang.srt.hf_transformers_utils import get_tokenizer


@torch.inference_mode()
def normal_text(args):
    t = get_tokenizer(args.model_path, trust_remote_code=True)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )

    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
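    # Prompts may also be given as lists of token ids; the loop below handles
    # both forms.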
    max_new_tokens = args.max_new_tokens

    torch.cuda.set_device(0)

    for i, p in enumerate(prompts):
        if isinstance(p, str):
            input_ids = t.encode(p, return_tensors="pt").to("cuda:0")
        else:
            input_ids = torch.tensor([p], device="cuda:0")

        output_ids = m.generate(
            input_ids, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = t.decode(output_ids[0])

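        # Recompute the forward pass over the prompt so the logits at the
        # final prompt position can be printed for numerical comparison.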
        prefill_logits = m.forward(input_ids).logits[0][-1]

        print(f"\n========== Prompt {i} ==========")
        print("prefill logits (final)", prefill_logits)
        print(output_str)


@torch.inference_mode()
def synthetic_tokens(args):
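    # Greedy-decode from a synthetic token-id prompt, printing the logits at
    # every step so each decode position can be compared against another
    # implementation.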
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    m.cuda()
    print(m)

    input_len = 256
    output_len = 8
    prompts = [list(range(5, 5 + input_len))]

    for p in prompts:
        input_ids = p
        for i in range(output_len + 1):
            # Rerun the full forward pass over the whole sequence at every
            # step; slow, but it sidesteps KV-cache logic entirely.
            logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[0][-1]

            if i == 0:
                print("prefill logits", logits)
            else:
                print("decode", i - 1, logits)

            # Greedy decoding: append the argmax token and continue.
            input_ids.append(torch.argmax(logits).item())

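
# A minimal sketch of how the printed prefill logits could be checked against
# another engine's output. `ref_logits` is a hypothetical tensor captured from
# that engine, and the default tolerance is only an assumption, since fp16
# logits rarely match bitwise across implementations.
def compare_logits(logits, ref_logits, atol=0.2):
    max_diff = (logits - ref_logits).abs().max().item()
    print("max logit diff", max_diff)
    return max_diff < atol
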

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
        # default="meta-llama/Llama-2-7b-chat-hf",
    )
    parser.add_argument("--max-new-tokens", type=int, default=16)
    parser.add_argument("--dtype", type=str, default="float16")

    args = parser.parse_args()

    normal_text(args)
    # synthetic_tokens(args)