"""A tensor parallel worker."""

import asyncio
import logging
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List

import rpyc
import torch
from rpyc.utils.classic import obtain

from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchTokenIDOut,
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import (
    BaseFinishReason,
    Batch,
    FINISH_ABORT,
    ForwardMode,
    Req,
)
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
    start_rpyc_process,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.tp_worker")


class ModelTpServer:
    def __init__(
        self,
        gpu_id: int,
        tp_rank: int,
        server_args: ServerArgs,
        model_port_args: ModelPortArgs,
        model_overide_args,
    ):
        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
        suppress_other_loggers()

        # Copy arguments
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.dp_size = server_args.dp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_overide_args=model_overide_args,
        )
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=model_port_args.nccl_port,
            server_args=server_args,
        )

        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        self.max_prefill_tokens = max(
            self.model_config.context_len,
            (
                min(self.max_total_num_tokens // 6, 65536)
                if server_args.max_prefill_tokens is None
                else server_args.max_prefill_tokens
            ),
        )
        self.max_running_requests = (
            self.max_total_num_tokens // 2
            if server_args.max_running_requests is None
            else server_args.max_running_requests
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)

        # Print info
        logger.info(
            f"[gpu_id={self.gpu_id}] "
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"context_len={self.model_config.context_len}, "
        )
        if self.tp_rank == 0:
            logger.info(
                f"[gpu_id={self.gpu_id}] "
                f"server_args: {server_args.print_mode_args()}"
            )

        # Init cache
        self.tree_cache = RadixCache(
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            disable=server_args.disable_radix_cache,
        )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = ScheduleHeuristic(
            self.schedule_heuristic,
            self.max_running_requests,
            self.max_prefill_tokens,
            self.max_total_num_tokens,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
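        # new_token_ratio estimates how much of each running request's remaining
        # max_new_tokens budget will actually be consumed before it finishes. It is
        # used in get_new_fill_batch() to reserve KV-cache space for in-flight
        # decodes, raised after a decode out-of-memory retraction, and decayed back
        # toward min_new_token_ratio otherwise.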
        self.new_token_ratio = min(
            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            global_config.base_min_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.new_token_ratio_decay = global_config.new_token_ratio_decay
        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
    def exposed_step(self, recv_reqs):
        if self.tp_size * self.dp_size != 1:
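            # With tp_size * dp_size > 1 the requests arrive as rpyc netrefs;
            # obtain() copies them into plain local objects before use.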
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                elif isinstance(recv_req, AbortReq):
                    self.abort_request(recv_req)
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
            raise

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run a new fill batch
            self.forward_fill_batch(new_batch)
            self.cache_filled_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

                    # Print stats
                    if self.tp_rank == 0:
                        if self.decode_forward_ct % 40 == 0:
                            num_used = self.max_total_num_tokens - (
                                self.token_to_kv_pool.available_size()
                                + self.tree_cache.evictable_size()
                            )
                            throughput = self.num_generated_tokens / (
                                time.time() - self.last_stats_tic
                            )
                            self.num_generated_tokens = 0
                            self.last_stats_tic = time.time()
                            logger.info(
                                f"[gpu_id={self.gpu_id}] Decode batch. "
                                f"#running-req: {len(self.running_batch.reqs)}, "
                                f"#token: {num_used}, "
                                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                                f"gen throughput (token/s): {throughput:.2f}, "
                                f"#queue-req: {len(self.forward_queue)}"
                            )

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
                        break
            else:
                # Check the available size
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                if available_size != self.max_total_num_tokens:
                    warnings.warn(
                        "Warning: "
                        f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                        "KV cache pool leak detected!"
                    )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
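            # Fold the image hash into four in-vocabulary token ids. These pad
            # values stand in for the image inside the padded prompt, so requests
            # carrying the same image yield identical token sequences (presumably
            # to keep such prompts shareable in the radix cache).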
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.image_size = recv_req.image_size
            req.origin_input_ids, req.image_offset = (
                self.model_runner.model.pad_input_ids(
                    req.origin_input_ids_unpadded,
                    req.pad_value,
                    req.pixel_values.shape,
                    req.image_size,
                )
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
                )

        # Truncate prompts that are too long
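        # Also clamp max_new_tokens so that prompt + output fits in the context
        # window and leaves a 128-token margin in the KV-cache pool.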
        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.origin_input_ids),
            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_running_requests
        ):
            return None

        # Compute matched prefix length
        for req in self.forward_queue:
            req.input_ids = req.origin_input_ids + req.output_ids
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

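        # Memory budget for the new prefill batch: free KV-cache slots plus
        # evictable radix-cache tokens, minus the tokens the running batch is
        # still expected to generate (discounted by new_token_ratio).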
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if self.running_batch:
            available_size -= sum(
                [
                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                    for r in self.running_batch.reqs
                ]
            )

        for req in self.forward_queue:
            if req.return_logprob and req.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.extend_input_len + new_batch_input_tokens
                < self.max_prefill_tokens
            ):
                delta = self.tree_cache.inc_lock_ref(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo locking
                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                    available_size += delta
                    break
                else:
                    # Add this request to the running batch
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.extend_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.extend_input_len
            else:
                break
        if len(can_run_list) == 0:
            return None

        # Print stats
        if self.tp_rank == 0:
            running_req = (
                0 if self.running_batch is None else len(self.running_batch.reqs)
            )
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
            logger.info(
                f"[gpu_id={self.gpu_id}] Prefill batch. "
                f"#new-seq: {len(can_run_list)}, "
                f"#new-token: {new_batch_input_tokens}, "
                f"#cached-token: {hit_tokens}, "
                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                f"#running-req: {running_req}, "
                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
            )
            # logger.debug(
            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
            # )

        # Return the new batch
        new_batch = Batch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_token_logprobs,
                normalized_prompt_logprobs,
                prefill_top_logprobs,
                decode_top_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
            if prefill_token_logprobs is not None:
                prefill_token_logprobs = prefill_token_logprobs.tolist()
                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

            next_token_ids, _ = batch.sample(logits)

            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
            if last_logprobs is not None:
                last_token_logprobs = last_logprobs[
                    torch.arange(len(batch.reqs), device=next_token_ids.device),
                    next_token_ids,
                ].tolist()

            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

        # Check finish condition
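        # `pt` walks the flattened logprob tensors returned by the extend forward;
        # each request owns a slice of length extend_input_len.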
        pt = 0
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_ids[i])
            req.check_finished()

            if req.return_logprob:
                if req.normalized_prompt_logprob is None:
                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

                if req.prefill_token_logprobs is None:
                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                    req.prefill_token_logprobs = list(
                        zip(
                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
                            req.input_ids[-req.extend_input_len + 1 :],
                        )
                    )
                    if req.logprob_start_len == 0:
                        req.prefill_token_logprobs = [
                            (None, req.input_ids[0])
                        ] + req.prefill_token_logprobs

                if req.last_update_decode_tokens != 0:
                    req.decode_token_logprobs.extend(
                        list(
                            zip(
                                prefill_token_logprobs[
                                    pt
                                    + req.extend_input_len
                                    - req.last_update_decode_tokens : pt
                                    + req.extend_input_len
                                    - 1
                                ],
                                req.input_ids[-req.last_update_decode_tokens + 1 :],
                            )
                        )
                    )

                req.decode_token_logprobs.append(
                    (last_token_logprobs[i], next_token_ids[i])
                )

            if req.top_logprobs_num > 0:
                if req.prefill_top_logprobs is None:
                    req.prefill_top_logprobs = prefill_top_logprobs[i]
                    if req.logprob_start_len == 0:
                        req.prefill_top_logprobs = [None] + req.prefill_top_logprobs

                if req.last_update_decode_tokens != 0:
                    req.decode_top_logprobs.extend(
                        prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                    )
                req.decode_top_logprobs.append(decode_top_logprobs[i])

            pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def cache_filled_batch(self, batch: Batch):
        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
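        # Re-insert each still-running request's tokens into the radix cache.
        # The just-sampled token is dropped ([:-1]) because its KV entry is only
        # written on the next forward pass; del_in_memory_pool=False keeps the
        # request's slots alive since it continues decoding.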
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
                del_in_memory_pool=False,
                old_last_node=req.last_node,
            )
            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

    def forward_decode_batch(self, batch: Batch):
        # check if decode out of memory
        if not batch.check_decode_mem():
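            # Decode ran out of KV-cache memory: become more conservative by
            # raising new_token_ratio, then retract some running requests back to
            # the waiting queue so they can be re-prefilled later.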
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)

            retracted_reqs = batch.retract_decode()
            logger.info(
                "decode out of memory happened, "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)

            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward
        logits, (
            _,
            _,
            _,
            decode_top_logprobs,
            last_logprobs,
        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(logits)
        next_token_ids = next_token_ids.tolist()

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        if last_logprobs is not None:
            new_token_logprobs = last_logprobs[
                torch.arange(len(batch.reqs)), next_token_ids
            ].tolist()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.return_logprob:
                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))

            if req.top_logprobs_num > 0:
                req.decode_top_logprobs.append(decode_top_logprobs[i])

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
        output_rids = []
        decoded_texts = []
        surr_output_ids = []
        read_output_ids = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished():
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

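            # Emit an output chunk when a request finishes, and for streaming
            # requests every `stream_interval` decode steps (or right after the
            # first generated token).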
            if req.finished() or (
                (
                    req.stream
                    and (
                        self.decode_forward_ct % self.stream_interval == 0
                        or len(req.output_ids) == 1
                    )
                )
            ):
                output_rids.append(req.rid)
                decoded_texts.append(req.decoded_text)
                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
                surr_output_ids.append(surr_ids)
                read_output_ids.append(read_ids)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                output_spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )

                meta_info = {
                    "prompt_tokens": len(req.origin_input_ids),
                    "completion_tokens": len(req.output_ids),
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": str(req.finished_reason),
                }
                if req.return_logprob:
                    (
                        meta_info["prefill_token_logprobs"],
                        meta_info["decode_token_logprobs"],
                        meta_info["prefill_top_logprobs"],
                        meta_info["decode_top_logprobs"],
                        meta_info["normalized_prompt_logprob"],
                    ) = (
                        req.prefill_token_logprobs,
                        req.decode_token_logprobs,
                        req.prefill_top_logprobs,
                        req.decode_top_logprobs,
                        req.normalized_prompt_logprob,
                    )
                output_meta_info.append(meta_info)
                output_finished_reason.append(req.finished_reason)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    decoded_texts,
                    surr_output_ids,
                    read_output_ids,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
                    output_finished_reason,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
                    req_pool_idx=req_pool_indices_cpu[i],
                )

                self.tree_cache.dec_lock_ref(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []

    def flush_cache(self):
        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

    def abort_request(self, recv_req):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.forward_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.forward_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid:
                    req.finished_reason = FINISH_ABORT()
                    break

class ModelTpService(rpyc.Service):
    exposed_ModelTpServer = ModelTpServer


class ModelTpClient:
    def __init__(
        self,
        gpu_ids: List[int],
        server_args: ServerArgs,
        model_port_args: ModelPortArgs,
        model_overide_args,
    ):
        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
        self.tp_size = server_args.tp_size
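        # Fast path: with tp_size * dp_size == 1 the ModelTpServer runs in-process
        # and rpyc is bypassed; otherwise one rpyc worker process is started per
        # tensor-parallel rank and calls are fanned out asynchronously.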
        if self.tp_size * server_args.dp_size == 1:
            # Init model
            assert len(gpu_ids) == 1
            self.model_server = ModelTpService().exposed_ModelTpServer(
                0,
                gpu_ids[0],
                server_args,
                model_port_args,
                model_overide_args,
            )

            # Wrap functions
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(self.tp_size) as executor:
                # Launch model processes
                rets = executor.map(
                    lambda args: start_rpyc_process(*args),
                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
                )
                self.model_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.model_services[i].ModelTpServer(
                        gpu_ids[i],
                        i,
                        server_args,
                        model_port_args,
                        model_overide_args,
                    )

                self.model_servers = executor.map(init_model, range(self.tp_size))

            # Wrap functions
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")