# bench_llama_low_api.py: low-level benchmark that times ModelRunner prefill,
# extend and decode steps on a shared-prefix workload.
import multiprocessing as mp
import time
from dataclasses import dataclass

import torch
import torch.distributed as dist
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig


@dataclass
class BenchBatch:
    """Mutable batch state driven by hand through prefill, extend and decode.

    One instance is reused across the benchmark:
      * ``init_prefill_batch`` sets up fresh requests with no prefix,
      * ``update_extend`` forks existing requests into new ones that share
        their prefix tokens and appends ``extend_len`` new tokens each,
      * ``update_decode`` appends one token per request.
    Every step allocates slots from the runner's memory pools and records them
    in ``req_to_token_pool.req_to_token``.
    """

    # Memory pools borrowed from the ModelRunner. These are pool objects (they
    # expose ``alloc`` / ``alloc_contiguous`` / ``req_to_token``), not tensors.
    req_to_token_pool: object
    token_to_kv_pool: object

    input_ids: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    seq_lens: torch.Tensor = None
    prefix_lens: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    def __init__(self, model_runner: "ModelRunner"):
        """Borrow the request-to-token and KV-cache pools from ``model_runner``."""
        self.req_to_token_pool = model_runner.req_to_token_pool
        self.token_to_kv_pool = model_runner.token_to_kv_pool

    @staticmethod
    def _require(result, what, size):
        """Return ``result`` or raise RuntimeError if a pool allocation failed.

        NOTE(review): assumes the pools signal exhaustion by returning None —
        confirm against the pool implementations. Any non-None result passes.
        """
        if result is None:
            raise RuntimeError(f"failed to allocate {size} {what} (pool exhausted?)")
        return result

    def init_prefill_batch(self, input_ids, batch_size, seq_len):
        """Set up ``batch_size`` new requests of ``seq_len`` tokens, no prefix.

        Args:
            input_ids: token ids for all requests, request-major (the KV slots
                below are assigned in consecutive ``seq_len`` chunks per request).
            batch_size: number of requests.
            seq_len: tokens per request.

        Raises:
            RuntimeError: if a pool allocation returns None.
        """
        # Allocate first so an exhausted pool fails before any tensors are built.
        self.req_pool_indices = self._require(
            self.req_to_token_pool.alloc(batch_size), "request slots", batch_size
        )
        self.out_cache_loc = self._require(
            self.token_to_kv_pool.alloc(batch_size * seq_len),
            "KV-cache slots",
            batch_size * seq_len,
        )

        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(
            batch_size, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.full(
            (batch_size,), seq_len, dtype=torch.int32, device="cuda"
        )
        self.prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

        # Request i owns KV slots [i * seq_len, (i + 1) * seq_len) of the allocation.
        req_to_token = self.req_to_token_pool.req_to_token
        for i in range(batch_size):
            n_idx = self.req_pool_indices[i].item()
            req_to_token[n_idx, :seq_len] = self.out_cache_loc[
                i * seq_len : (i + 1) * seq_len
            ]

    def update_extend(
        self, input_ids, batch_size, prefix_len, extend_len, prefix_req_idx
    ):
        """Fork the requests in ``prefix_req_idx`` and extend each fork.

        Each prefix request is copied into ``batch_size // len(prefix_req_idx)``
        new requests (consecutive new requests share a parent). Each new request
        reuses the parent's first ``prefix_len`` KV slots and gets
        ``extend_len`` freshly allocated slots after them.

        Args:
            input_ids: ids of the ``extend_len`` new tokens of every request,
                request-major.
            batch_size: number of new requests.
            prefix_len: number of tokens shared with the parent request.
            extend_len: number of new tokens per request.
            prefix_req_idx: 1-D tensor of parent request-pool indices.

        Raises:
            ValueError: if there are no parents or ``batch_size`` is not a
                multiple of their count.
            RuntimeError: if a pool allocation returns None.
        """
        num_prefixes = prefix_req_idx.shape[0]
        if num_prefixes == 0 or batch_size % num_prefixes != 0:
            raise ValueError(
                f"batch_size ({batch_size}) must be a multiple of the number of "
                f"prefix requests ({num_prefixes})"
            )

        self.req_pool_indices = self._require(
            self.req_to_token_pool.alloc(batch_size), "request slots", batch_size
        )
        self.out_cache_loc = self._require(
            self.token_to_kv_pool.alloc(batch_size * extend_len),
            "KV-cache slots",
            batch_size * extend_len,
        )

        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(
            batch_size, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.full(
            (batch_size,), prefix_len + extend_len, dtype=torch.int32, device="cuda"
        )
        self.prefix_lens = torch.full(
            (batch_size,), prefix_len, dtype=torch.int32, device="cuda"
        )

        req_to_token = self.req_to_token_pool.req_to_token
        fork_num = batch_size // num_prefixes
        for i in range(batch_size):
            p_idx = prefix_req_idx[i // fork_num].item()
            n_idx = self.req_pool_indices[i].item()
            # Share the parent's prefix slots, then append this fork's own slots.
            req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
            req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
                self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
            )

    def update_decode(self, predict_ids, batch_size):
        """Append one token per request for the next decode step.

        Args:
            predict_ids: the tokens predicted by the previous step, one per
                request (flattened into ``input_ids``).
            batch_size: number of requests; must match ``predict_ids`` and the
                current ``req_pool_indices``.

        Raises:
            RuntimeError: if the contiguous allocation returns None.
        """
        assert predict_ids.shape[0] == batch_size
        assert batch_size == self.req_pool_indices.shape[0]

        self.input_ids = predict_ids.reshape(-1)
        self.prefix_lens = None
        (
            self.out_cache_loc,
            self.out_cache_cont_start,
            self.out_cache_cont_end,
        ) = self._require(
            self.token_to_kv_pool.alloc_contiguous(batch_size),
            "contiguous KV-cache slots",
            batch_size,
        )

        # Record each new slot at position seq_lens, i.e. right after the
        # request's current last token, then grow every sequence by one.
        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
            self.out_cache_loc
        )
        self.seq_lens.add_(1)


def prefill(model_runner: ModelRunner, batch: BenchBatch):
    """Run the prefill forward pass for ``batch`` and greedily pick next tokens.

    Calls ``model_runner.forward_extend`` positionally with the batch tensors
    and a trailing ``False`` flag (``extend`` in this file passes ``True``; the
    flag's meaning is defined by ModelRunner). The call must return a 2-tuple
    whose first item is the logits.

    Returns:
        Host NumPy array of argmax token ids along dim 1, with that dim kept
        (shape ``(rows, 1)`` assuming 2-D logits).
    """
    forward_args = (
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        False,
    )
    logits, _aux = model_runner.forward_extend(*forward_args)

    probs = torch.softmax(logits, dim=-1)
    next_ids = probs.argmax(dim=1, keepdim=True)
    return next_ids.detach().cpu().numpy()


def extend(model_runner: ModelRunner, batch: BenchBatch):
    """Run the extend forward pass for ``batch`` and greedily pick next tokens.

    Identical to ``prefill`` except that the trailing positional flag passed to
    ``model_runner.forward_extend`` is ``True`` (its meaning is defined by
    ModelRunner). The call must return a 2-tuple whose first item is the logits.

    Returns:
        Host NumPy array of argmax token ids along dim 1, with that dim kept
        (shape ``(rows, 1)`` assuming 2-D logits).
    """
    logits, _aux = model_runner.forward_extend(
        *(
            batch.input_ids,
            batch.req_pool_indices,
            batch.seq_lens,
            batch.prefix_lens,
            batch.position_ids_offsets,
            batch.out_cache_loc,
            True,
        )
    )

    next_ids = torch.softmax(logits, dim=-1).argmax(dim=1, keepdim=True)
    host_ids = next_ids.detach().cpu()
    return host_ids.numpy()


def decode(model_runner: ModelRunner, batch: BenchBatch):
    """Run one decode step for ``batch`` and greedily pick next tokens.

    Calls ``model_runner.forward_decode`` positionally with the batch tensors,
    ``None`` in the 4th and 6th positions, and the contiguous cache range
    produced by ``BenchBatch.update_decode``. Unlike ``forward_extend``, its
    return value is used directly as the logits.

    Returns:
        Host NumPy array of argmax token ids along dim 1, with that dim kept
        (shape ``(rows, 1)`` assuming 2-D logits).
    """
    decode_args = (
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        None,
        batch.position_ids_offsets,
        None,
        batch.out_cache_cont_start,
        batch.out_cache_cont_end,
    )
    logits = model_runner.forward_decode(*decode_args)

    next_ids = torch.softmax(logits, dim=-1).argmax(dim=1, keepdim=True)
    return next_ids.detach().cpu().numpy()


def _run_generation(
    model_runner,
    batch,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    report=False,
):
    """Run one prefill -> extend -> decode pass of the shared-prefix workload.

    Prefills ``shared_num`` prompts of ``shared_len`` random token ids, forks
    them into ``unique_num`` requests that each extend by ``unique_len`` random
    ids, then runs ``decode_len`` greedy decode steps. When ``report`` is true,
    each phase's wall time is printed in milliseconds.
    """
    # Each phase's timer starts after draining queued GPU work, so earlier work
    # is not charged to it. perf_counter is monotonic, unlike time.time.
    torch.cuda.synchronize()
    start = time.perf_counter()
    input_ids = torch.randint(low=5, high=100, size=(shared_num * shared_len,)).cuda()
    batch.init_prefill_batch(input_ids, shared_num, shared_len)
    prefill(model_runner, batch)
    if report:
        print(f"prefill: {(time.perf_counter() - start) * 1000:.2f} ms")

    torch.cuda.synchronize()
    start = time.perf_counter()
    input_ids = torch.randint(low=5, high=100, size=(unique_num * unique_len,)).cuda()
    # The prefill requests (still in batch.req_pool_indices) are the parents.
    batch.update_extend(
        input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
    )
    predict_ids = extend(model_runner, batch)
    if report:
        print(f"extend: {(time.perf_counter() - start) * 1000:.2f} ms")

    for i in range(decode_len):
        torch.cuda.synchronize()
        start = time.perf_counter()
        predict_ids = torch.from_numpy(predict_ids).cuda()
        batch.update_decode(predict_ids, unique_num)
        predict_ids = decode(model_runner, batch)
        if report:
            print(f"decode {i}: {(time.perf_counter() - start) * 1000:.2f} ms")


def bench_generate_worker(
    model_path,
    tp_rank,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    server_args_dict,
    warmup_iters=1,
):
    """Benchmark prefill/extend/decode latency on one tensor-parallel rank.

    Builds a ModelRunner for ``model_path`` and runs ``warmup_iters`` untimed
    passes, resetting the memory pools after each. It then runs one timed pass
    whose per-phase latencies are printed by rank 0 only.

    Args:
        model_path: model name or path passed to ModelConfig.
        tp_rank: this process's tensor-parallel rank.
        tp_size: total number of tensor-parallel ranks.
        shared_num: number of shared-prefix prompts.
        unique_num: number of forked requests; must be a multiple of
            ``shared_num``.
        shared_len: tokens per shared prompt.
        unique_len: unique tokens appended to each forked request.
        decode_len: number of decode steps.
        server_args_dict: forwarded to ModelRunner as ``server_args_dict``.
        warmup_iters: number of untimed warm-up passes (default 1).

    Raises:
        ValueError: if ``shared_num`` is not positive or does not divide
            ``unique_num``.
    """
    if shared_num <= 0 or unique_num % shared_num != 0:
        raise ValueError(
            f"unique_num ({unique_num}) must be a multiple of a positive "
            f"shared_num ({shared_num})"
        )

    model_config = ModelConfig(path=model_path)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=0.8,
        tp_rank=tp_rank,
        tp_size=tp_size,
        nccl_port=28888,
        server_args_dict=server_args_dict,
    )
    batch = BenchBatch(model_runner)

    # Untimed warm-up; reset the pools via clear() so the timed pass starts fresh.
    for _ in range(warmup_iters):
        _run_generation(
            model_runner,
            batch,
            shared_num,
            unique_num,
            shared_len,
            unique_len,
            decode_len,
        )
        model_runner.req_to_token_pool.clear()
        model_runner.token_to_kv_pool.clear()

    # Line all ranks up so the timed pass starts together.
    if tp_size > 1:
        dist.barrier()

    _run_generation(
        model_runner,
        batch,
        shared_num,
        unique_num,
        shared_len,
        unique_len,
        decode_len,
        report=tp_rank == 0,
    )


def bench_generate(
    model_path,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    server_args_dict,
):
    """Print the benchmark configuration and run one worker process per TP rank.

    Spawns ``tp_size`` processes running ``bench_generate_worker`` with ranks
    ``0..tp_size-1`` and waits for all of them.

    Raises:
        RuntimeError: if any worker exits with a non-zero exit code. Without
            this check a crashed rank would go unnoticed.
    """
    print(
        f"tp_size: {tp_size}, "
        f"shared_num: {shared_num}, "
        f"unique_num: {unique_num}, "
        f"shared_len: {shared_len}, "
        f"unique_len: {unique_len}, "
        f"decode_len: {decode_len}, "
        f"server_args: {server_args_dict}"
    )
    workers = []
    for tp_rank in range(tp_size):
        proc = mp.Process(
            target=bench_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
                shared_num,
                unique_num,
                shared_len,
                unique_len,
                decode_len,
                server_args_dict,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

    # NOTE(review): if one rank dies while others wait on a collective, join()
    # may block indefinitely; this check only fires once every worker exited.
    failed = [
        (tp_rank, proc.exitcode)
        for tp_rank, proc in enumerate(workers)
        if proc.exitcode != 0
    ]
    if failed:
        raise RuntimeError(
            f"benchmark worker(s) failed as (tp_rank, exitcode): {failed}"
        )


if __name__ == "__main__":
    # Single-GPU run: one 256-token shared prompt forked into 32 requests that
    # each append 256 unique tokens, followed by 8 decode steps.
    bench_config = dict(
        model_path="meta-llama/Llama-2-7b-chat-hf",
        tp_size=1,
        shared_num=1,
        unique_num=32,
        shared_len=256,
        unique_len=256,
        decode_len=8,
        server_args_dict={},
    )
    bench_generate(**bench_config)