import asyncio
import logging
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List

import rpyc
import torch
from rpyc.utils.classic import obtain

from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchTokenIDOut,
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
    start_rpyc_process,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.model_tp")


class ModelTpServer:
    def __init__(
        self,
        gpu_id: int,
        tp_rank: int,
        server_args: ServerArgs,
        model_port_args: ModelPortArgs,
        model_overide_args,
    ):
        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
        suppress_other_loggers()

        # Copy arguments
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.dp_size = server_args.dp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_overide_args=model_overide_args,
        )
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=model_port_args.nccl_port,
            server_args=server_args,
        )

        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        self.max_prefill_tokens = max(
            self.model_config.context_len,
            (
                min(self.max_total_num_tokens // 6, 65536)
                if server_args.max_prefill_tokens is None
                else server_args.max_prefill_tokens
            ),
        )
        self.max_running_requests = (
            self.max_total_num_tokens // 2
            if server_args.max_running_requests is None
            else server_args.max_running_requests
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)

        # Print info
        logger.info(
            f"[gpu_id={self.gpu_id}] "
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"context_len={self.model_config.context_len}, "
        )
        if self.tp_rank == 0:
            logger.info(
                f"[gpu_id={self.gpu_id}] "
                f"server_args: {server_args.print_mode_args()}"
            )

        # Init cache
        self.tree_cache = RadixCache(
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            disable=server_args.disable_radix_cache,
        )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = ScheduleHeuristic(
            self.schedule_heuristic,
            self.max_running_requests,
            self.max_prefill_tokens,
            self.max_total_num_tokens,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init the FSM cache for constrained generation
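        # Compiled FSMs for regex-constrained decoding are cached and looked up
        # by the regex string, so repeated requests with the same constraint
        # reuse the compiled automaton instead of rebuilding it.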
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()

        # Init new token estimation
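        # new_token_ratio estimates what fraction of each running request's
        # remaining max_new_tokens budget must stay reserved in the KV cache
        # when admitting new prefill requests. It decays toward
        # min_new_token_ratio while decoding succeeds and is raised again after
        # a decode out-of-memory retraction (see forward_decode_batch).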
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.new_token_ratio = min(
            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            global_config.base_min_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.new_token_ratio_decay = global_config.new_token_ratio_decay
        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

    def exposed_step(self, recv_reqs):
        if self.tp_size * self.dp_size != 1:
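            # With tp_size or dp_size > 1, requests arrive through rpyc as remote
            # references, so materialize local copies before using them.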
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                elif isinstance(recv_req, AbortReq):
                    self.abort_request(recv_req)
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run a new fill batch
            self.forward_fill_batch(new_batch)
            self.cache_filled_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

                    # Print stats
                    if self.tp_rank == 0:
                        if self.decode_forward_ct % 40 == 0:
                            num_used = self.max_total_num_tokens - (
                                self.token_to_kv_pool.available_size()
                                + self.tree_cache.evictable_size()
                            )
                            throughput = self.num_generated_tokens / (
                                time.time() - self.last_stats_tic
                            )
                            self.num_generated_tokens = 0
                            self.last_stats_tic = time.time()
                            logger.info(
                                f"[gpu_id={self.gpu_id}] Decode batch. "
                                f"#running-req: {len(self.running_batch.reqs)}, "
                                f"#token: {num_used}, "
                                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                                f"gen throughput (token/s): {throughput:.2f}, "
                                f"#queue-req: {len(self.forward_queue)}"
                            )

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
                        break
            else:
                # Check the available size
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                if available_size != self.max_total_num_tokens:
                    warnings.warn(
                        "Warning: "
                        f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                        "KV cache pool leak detected!"
                    )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
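            # Derive pad token ids from the image hash so the padded prompt is
            # unique per image and different images do not share a radix-cache prefix.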
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.image_size = recv_req.image_size
            req.origin_input_ids, req.image_offset = (
                self.model_runner.model.pad_input_ids(
                    req.origin_input_ids_unpadded,
                    req.pad_value,
                    req.pixel_values.shape,
                    req.image_size,
                )
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
                )

        # Truncate prompts that are too long
        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.origin_input_ids),
            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_running_requests
        ):
            return None

        # Compute matched prefix length
        for req in self.forward_queue:
            assert (
                len(req.output_ids) == 0
            ), "The output ids should be empty when prefilling"
            req.input_ids = req.origin_input_ids + req.prev_output_ids
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if self.running_batch:
            available_size -= sum(
                [
                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                    for r in self.running_batch.reqs
                ]
            )

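        # Admit queued requests while two budgets hold: the projected KV usage
        # (extend tokens plus reserved new tokens) must fit in available_size,
        # and the batch's total extend tokens must stay under max_prefill_tokens.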
        for req in self.forward_queue:
            if req.return_logprob and req.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.extend_input_len + new_batch_input_tokens
                < self.max_prefill_tokens
            ):
                delta = self.tree_cache.inc_lock_ref(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo locking
                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                    available_size += delta
                    break
                else:
                    # Add this request to the running batch
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.extend_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.extend_input_len
            else:
                break
        if len(can_run_list) == 0:
            return None

        # Print stats
        if self.tp_rank == 0:
            running_req = (
                0 if self.running_batch is None else len(self.running_batch.reqs)
            )
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
            logger.info(
                f"[gpu_id={self.gpu_id}] Prefill batch. "
                f"#new-seq: {len(can_run_list)}, "
                f"#new-token: {new_batch_input_tokens}, "
                f"#cached-token: {hit_tokens}, "
                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                f"#running-req: {running_req}, "
                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
            )
            # logger.debug(
            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
            # )

        # Return the new batch
        new_batch = Batch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_token_logprobs,
                normalized_prompt_logprobs,
                prefill_top_logprobs,
                decode_top_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
            if prefill_token_logprobs is not None:
                prefill_token_logprobs = prefill_token_logprobs.tolist()
                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

            next_token_ids, _ = batch.sample(logits)

            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
            if last_logprobs is not None:
                last_token_logprobs = last_logprobs[
                    torch.arange(len(batch.reqs), device=next_token_ids.device),
                    next_token_ids,
                ].tolist()

            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

        # Check finish condition
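        # pt is the offset of the current request's extend tokens inside the
        # flattened per-batch prefill logprob outputs.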
        pt = 0
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids = [next_token_ids[i]]
            req.check_finished()

            if req.return_logprob:
                if req.normalized_prompt_logprob is None:
                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

                if req.prefill_token_logprobs is None:
                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                    req.prefill_token_logprobs = list(
                        zip(
                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
                            req.input_ids[-req.extend_input_len + 1 :],
                        )
                    )
                    if req.logprob_start_len == 0:
                        req.prefill_token_logprobs = [
                            (None, req.input_ids[0])
                        ] + req.prefill_token_logprobs

                if req.last_update_decode_tokens != 0:
                    req.decode_token_logprobs.extend(
                        list(
                            zip(
                                prefill_token_logprobs[
                                    pt
                                    + req.extend_input_len
                                    - req.last_update_decode_tokens : pt
                                    + req.extend_input_len
                                    - 1
                                ],
                                req.input_ids[-req.last_update_decode_tokens + 1 :],
                            )
                        )
                    )

                req.decode_token_logprobs.append(
                    (last_token_logprobs[i], next_token_ids[i])
                )

            if req.top_logprobs_num > 0:
                if req.prefill_top_logprobs is None:
                    req.prefill_top_logprobs = prefill_top_logprobs[i]
                    if req.logprob_start_len == 0:
                        req.prefill_top_logprobs = [None] + req.prefill_top_logprobs

                if req.last_update_decode_tokens != 0:
                    req.decode_top_logprobs.extend(
                        prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                    )
                req.decode_top_logprobs.append(decode_top_logprobs[i])

            pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def cache_filled_batch(self, batch: Batch):
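        # Insert the freshly prefilled tokens into the radix cache and re-anchor
        # each request's prefix on the returned cache node so its KV entries
        # remain shared and lock-protected while the request keeps running.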
        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
                del_in_memory_pool=False,
                old_last_node=req.last_node,
            )
            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

    def forward_decode_batch(self, batch: Batch):
        # check if decode out of memory
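        # If the KV cache cannot fit one more token per running request, retract
        # some requests back to the waiting queue and raise new_token_ratio so
        # future prefill admission becomes more conservative.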
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)

            retracted_reqs = batch.retract_decode()
            logger.info(
                "decode out of memory happened, "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # check for jump-forward
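            # Jump-forward appends regex-determined text to a request without
            # decoding it token by token; affected requests are re-queued and go
            # through the prefill path again.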
            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)

            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward
        logits, (
            _,
            _,
            _,
            decode_top_logprobs,
            last_logprobs,
        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(logits)
        next_token_ids = next_token_ids.tolist()

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        if last_logprobs is not None:
            new_token_logprobs = last_logprobs[
                torch.arange(len(batch.reqs)), next_token_ids
            ].tolist()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.return_logprob:
                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))

            if req.top_logprobs_num > 0:
                req.decode_top_logprobs.append(decode_top_logprobs[i])

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
        output_rids = []
        prev_output_strs = []
        output_tokens = []
        output_hit_stop_str = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
        output_finished = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished:
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

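            # Emit output when the request has finished, or stream a partial
            # update every stream_interval decode steps (and right after the
            # first generated token).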
            if req.finished or (
                (
                    req.stream
                    and (
                        self.decode_forward_ct % self.stream_interval == 0
                        or len(req.output_ids) == 1
                    )
                )
            ):
                output_rids.append(req.rid)
                prev_output_strs.append(req.prev_output_str)
                output_tokens.append(req.output_ids)
                output_hit_stop_str.append(req.hit_stop_str)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                output_spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )

                meta_info = {
                    "prompt_tokens": len(req.origin_input_ids),
                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": FinishReason.to_str(req.finish_reason),
                    "hit_stop_str": req.hit_stop_str,
                }
                if req.return_logprob:
                    (
                        meta_info["prefill_token_logprobs"],
                        meta_info["decode_token_logprobs"],
                        meta_info["prefill_top_logprobs"],
                        meta_info["decode_top_logprobs"],
                        meta_info["normalized_prompt_logprob"],
                    ) = (
                        req.prefill_token_logprobs,
                        req.decode_token_logprobs,
                        req.prefill_top_logprobs,
                        req.decode_top_logprobs,
                        req.normalized_prompt_logprob,
                    )
                output_meta_info.append(meta_info)
                output_finished.append(req.finished)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    prev_output_strs,
                    output_tokens,
                    output_hit_stop_str,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
                    output_finished,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
                    req_pool_idx=req_pool_indices_cpu[i],
                )

                self.tree_cache.dec_lock_ref(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []

    def flush_cache(self):
        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

    def abort_request(self, recv_req):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.forward_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.forward_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid:
                    req.finished = True
                    req.finish_reason = FinishReason.ABORT
                    break


class ModelTpService(rpyc.Service):
    exposed_ModelTpServer = ModelTpServer


class ModelTpClient:
    def __init__(
        self,
        gpu_ids: List[int],
        server_args: ServerArgs,
        model_port_args: ModelPortArgs,
        model_overide_args,
    ):
        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
        self.tp_size = server_args.tp_size

        if self.tp_size * server_args.dp_size == 1:
            # Init model
            assert len(gpu_ids) == 1
            self.model_server = ModelTpService().exposed_ModelTpServer(
                0,
                gpu_ids[0],
                server_args,
                model_port_args,
                model_overide_args,
            )

            # Wrap functions
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(self.tp_size) as executor:
                # Launch model processes
                rets = executor.map(
                    lambda args: start_rpyc_process(*args),
                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
                )
                self.model_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.model_services[i].ModelTpServer(
                        gpu_ids[i],
                        i,
                        server_args,
                        model_port_args,
                        model_overide_args,
                    )

                self.model_servers = executor.map(init_model, range(self.tp_size))

            # Wrap functions
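            # Each TP rank runs in its own rpyc process; `step` dispatches the
            # call to every rank asynchronously, waits for all of them off the
            # event loop, and returns rank 0's result (all ranks are expected to
            # produce identical outputs).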
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")