test_llama_extend.py 3.84 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import multiprocessing
import os
import time

import numpy as np
import torch
import torch.distributed as dist
import transformers
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams


def test_generate_worker(model_path, tp_rank, tp_size):
    model_config = ModelConfig(path=model_path)
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

    # Input
    prompts = [
        "The capital of France is",
        "Today is a sunny day and I like",
    ]
    sampling_params = SamplingParams(temperature=0)

    cut_num = 4

    reqs = []
    for i in range(len(prompts)):
        req = Req(i)
        req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
        req.sampling_params = sampling_params
        reqs.append(req)

    # Prefill
    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.init_extend_batch(model.model_config.vocab_size(), None)
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)
    print("extend logits (first)", logits)

    # Extend
    for i in range(len(prompts)):
        req = reqs[i]
        req.input_ids += tokenizer.encode(prompts[i])[cut_num:]
        req.prefix_indices = model.req_to_token_pool.req_to_token[
            batch.req_pool_indices[i], :cut_num
        ]
    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.init_extend_batch(model.model_config.vocab_size(), None)
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)

    print("extend logits", logits)
    print(
        "next_token_ids", next_token_ids, [tokenizer.decode(x) for x in next_token_ids]
    )

    # Decode
    for i in range(6):
        batch.update_for_decode(next_token_ids.cpu().numpy())
        logits = model.forward(batch, ForwardMode.DECODE)
        next_token_ids, next_token_probs = batch.sample(logits)

        print(
            "next_token_ids",
            next_token_ids,
            [tokenizer.decode(x) for x in next_token_ids],
        )


def test_generate(model_path, tp_size):
    workers = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()


if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1)

    # Reference output for TinyLlama-1.1B-Chat-v0.4
    # extend logits (first) tensor([[-10.0312,  -9.5000,   0.8896,  ...,  -4.9375,  -3.2402,  -3.3633],
    #             [ -9.1797, -10.2500,   2.7168,  ...,  -4.3359,  -4.0664,  -4.1289]],
    #                    device='cuda:0', dtype=torch.float16)
    # extend logits tensor([[-8.3125, -7.1172,  3.3359,  ..., -4.9531, -4.1289, -3.4121],
    #             [-9.6406, -9.0547,  4.0195,  ..., -5.3086, -4.7188, -4.4609]],
    #                    device='cuda:0', dtype=torch.float16)
    # next_token_ids tensor([3681,  304], device='cuda:0') ['Paris', 'to']
    # next_token_ids tensor([29889,   748], device='cuda:0') ['.', 'go']
    # next_token_ids tensor([ 13, 363], device='cuda:0') ['\n', 'for']
    # next_token_ids tensor([1576,  263], device='cuda:0') ['The', 'a']
    # next_token_ids tensor([7483, 6686], device='cuda:0') ['capital', 'walk']
    # next_token_ids tensor([310, 297], device='cuda:0') ['of', 'in']
    # next_token_ids tensor([278, 278], device='cuda:0') ['the', 'the']