reference_hf.py
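# Reference run using HuggingFace transformers: greedy-decodes a few prompts
# and prints the raw logits, presumably so another inference engine's output
# can be checked against it.
#
# Usage: python reference_hf.py --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4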
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


@torch.inference_mode()
def normal_text(args):
    t = AutoTokenizer.from_pretrained(args.model_path)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    m.cuda()

    print(m)

    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    max_new_tokens = 32

    for p in prompts:
        # Prompts may be raw strings or pre-tokenized lists of token IDs.
        if isinstance(p, str):
            input_ids = t.encode(p, return_tensors="pt").cuda()
        else:
            input_ids = torch.tensor([p], device="cuda")

        output_ids = m.generate(
            input_ids, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = t.decode(output_ids[0])
        print(output_str)

        # Re-run a full forward pass and print the logits at the last prompt
        # position (the prefill logits) for numerical comparison.
        prefill_logits = m(input_ids).logits[0][-1]
        print("prefill logits", prefill_logits)


@torch.inference_mode()
def synthetic_tokens(args):
    t = AutoTokenizer.from_pretrained(args.model_path)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    m.cuda()
    print(m)

    input_len = 256
    output_len = 8
    # A synthetic prompt: 256 consecutive token IDs starting at 5.
    prompts = [list(range(5, 5 + input_len))]

    for p in prompts:
        input_ids = list(p)  # copy so the prompt list is not mutated
        # Greedy decoding without a KV cache: re-run the full forward pass at
        # every step and append the argmax token.
        for i in range(output_len + 1):
            logits = m(torch.tensor([input_ids], device="cuda")).logits[0][-1]

            if i == 0:
                print("prefill logits", logits)
            else:
                print("decode", i - 1, logits)

            input_ids.append(torch.argmax(logits).item())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
        # default="meta-llama/Llama-2-7b-chat-hf",
    )
    args = parser.parse_args()

    normal_text(args)
    # synthetic_tokens(args)
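    # Toggle the two calls above to run the synthetic-token check
    # (synthetic_tokens) instead of the text prompts (normal_text).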