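"""Model RPC worker and client for the sglang router.

ModelRpcServer is the per-tensor-parallel-rank worker: it owns the ModelRunner,
the radix prefix cache, and the KV-cache pools, schedules prefill ("fill") and
decode batches, and emits BatchTokenIDOut objects for the detokenizer.
ModelRpcClient drives a single in-process worker when tp_size == 1, or one rpyc
worker process per rank otherwise.
"""
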
import asyncio
import logging
import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List

import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
try:
    from vllm.logger import _default_handler as vllm_default_logger
except ImportError:
    from vllm.logger import logger as vllm_default_logger

from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    BatchTokenIDOut,
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    get_exception_traceback,
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
)

logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)


class ModelRpcServer:
    def __init__(
        self,
        tp_rank: int,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        server_args, port_args = [obtain(x) for x in [server_args, port_args]]

        # Copy arguments
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
        )

        # For model end global settings
        server_args_dict = {
            "enable_flashinfer": server_args.enable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }

        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=port_args.nccl_port,
            load_format=server_args.load_format,
            trust_remote_code=server_args.trust_remote_code,
            server_args_dict=server_args_dict,
        )
        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.max_total_num_token = self.model_runner.max_total_num_token
        self.max_num_running_seq = self.max_total_num_token // 2
        self.max_prefill_num_token = max(
            self.model_config.context_len,
            (
                self.max_total_num_token // 6
                if server_args.max_prefill_num_token is None
                else server_args.max_prefill_num_token
            ),
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)
        logger.info(
            f"Rank {self.tp_rank}: "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
        )
        if self.tp_rank == 0:
            logger.info(f"server_args: {server_args.print_mode_args()}")

        # Init cache
        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = Scheduler(
            self.schedule_heuristic,
            self.max_num_running_seq,
            self.max_prefill_num_token,
            self.max_total_num_token,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()

        # Init new token estimation
        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
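        # Example: with schedule_conservativeness=1.0, the ratio starts at 0.4,
        # decays by 0.0001 per decode step toward the 0.2 floor, and jumps up by
        # 0.05 whenever decoding runs out of memory and requests get retracted.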

    def flush_cache(self):
        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

    def exposed_step(self, recv_reqs):
        if self.tp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run new fill batch
            self.forward_fill_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.forward_decode_batch(self.running_batch)

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
                        break

                    if self.running_batch is not None and self.tp_rank == 0:
                        if self.decode_forward_ct % 40 == 0:
                            num_used = self.max_total_num_token - (
                                self.token_to_kv_pool.available_size()
                                + self.tree_cache.evictable_size()
                            )
                            logger.info(
                                f"#running-req: {len(self.running_batch.reqs)}, "
                                f"#token: {num_used}, "
                                f"token usage: {num_used / self.max_total_num_token:.2f}, "
                                f"#queue-req: {len(self.forward_queue)}"
                            )
            else:
                # check the available size
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                if available_size != self.max_total_num_token:
                    warnings.warn(
                        "Warning: "
                        f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                        "KV cache pool leak detected!"
                    )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
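            # Note: the pad values above are 16-bit slices of the image hash taken
            # modulo the vocab size, so the placeholder tokens for the image region
            # are effectively unique per image; presumably this keeps radix-cache
            # prefix matches from colliding across different images.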
            req.image_size = recv_req.image_size
            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
                )

        # Truncate long prompts
        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.input_ids),
            self.max_total_num_token - 128 - len(req.input_ids),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_num_running_seq
        ):
            return None

        for req in self.forward_queue:
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if self.running_batch:
            available_size -= sum(
                [
                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                    for r in self.running_batch.reqs
                ]
            )
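        # Admission rule used in the loop below: a queued request is added only if
        # its prefill tokens plus a worst-case max_new_tokens reservation stay under
        # available_size, and its prefill tokens alone stay under
        # self.max_prefill_num_token for this batch.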

        for req in self.forward_queue:
            if req.return_logprob:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.extend_input_len + new_batch_input_tokens
                < self.max_prefill_num_token
            ):
                delta = self.tree_cache.inc_ref_counter(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo the insertion
                    delta = self.tree_cache.dec_ref_counter(req.last_node)
                    available_size += delta
                    break
                else:
                    # Add this request to the running batch
                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.extend_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.extend_input_len
            else:
                break
        if len(can_run_list) == 0:
            return None

        if self.tp_rank == 0:
            running_req = (
                0 if self.running_batch is None else len(self.running_batch.reqs)
            )
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
            logger.info(
                f"new fill batch. #seq: {len(can_run_list)}. "
                f"#cached_token: {hit_tokens}. "
                f"#new_token: {new_batch_input_tokens}. "
                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
                f"#running_req: {running_req}. "
                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
            )
            #logger.debug(
            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
            #)

        new_batch = Batch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_token_logprobs,
                normalized_prompt_logprobs,
                prefill_top_logprobs,
                decode_top_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
            if prefill_token_logprobs is not None:
                prefill_token_logprobs = prefill_token_logprobs.tolist()
                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

            next_token_ids, _ = batch.sample(logits)

            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
            if last_logprobs is not None:
                last_token_logprobs = (
                    last_logprobs[
                        torch.arange(len(batch.reqs), device=next_token_ids.device),
                        next_token_ids,
                    ].tolist()
                )

            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

        # Check finish condition
        pt = 0
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids = [next_token_ids[i]]
            req.check_finished()

            if req.return_logprob:
                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                req.prefill_token_logprobs = list(
                    zip(
                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
                        req.input_ids[-req.extend_input_len + 1 :],
                    )
                )
                if req.logprob_start_len == 0:
                    req.prefill_token_logprobs = [
                        (None, req.input_ids[0])
                    ] + req.prefill_token_logprobs
                req.decode_token_logprobs = [
                    (last_token_logprobs[i], next_token_ids[i])
                ]

            if req.top_logprobs_num > 0:
                req.prefill_top_logprobs = prefill_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
                req.decode_top_logprobs = [decode_top_logprobs[i]]

            pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def forward_decode_batch(self, batch: Batch):
        # check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)

            retracted_reqs = batch.retract_decode()
            logger.info(
                "decode out of memory happened, "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_step[0],
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward()

            # check for image jump-forward
            for req in jump_forward_reqs:
                if req.pixel_values is not None:
                    (
                        req.input_ids,
                        req.image_offset,
                    ) = self.model_runner.model.pad_input_ids(
                        req.input_ids,
                        req.pad_value,
                        req.pixel_values.shape,
                        req.image_size,
                    )

            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
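        # The counter wraps at 2**30 so it stays bounded; forward_step uses it to
        # log stats every 40 decode steps, and handle_finished_requests uses it to
        # throttle streaming output to every `stream_interval` steps.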
        batch.prepare_for_decode()

        # Forward
        logits, (
            _,
            _,
            _,
            decode_top_logprobs,
            last_logprobs,
        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(logits)
        next_token_ids = next_token_ids.tolist()

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        if last_logprobs is not None:
            new_token_logprobs = last_logprobs[
                torch.arange(len(batch.reqs)), next_token_ids
            ].tolist()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.return_logprob:
                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))

            if req.top_logprobs_num > 0:
                req.decode_top_logprobs.append(decode_top_logprobs[i])

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
        output_rids = []
        output_tokens = []
        output_and_jump_forward_strs = []
        output_hit_stop_str = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
        output_finished = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished:
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

            if req.finished or (
                (
                    req.stream
                    and (
                        self.decode_forward_ct % self.stream_interval == 0
                        or len(req.output_ids) == 1
                    )
                )
            ):
                output_rids.append(req.rid)
                output_tokens.append(req.output_ids)
                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                output_hit_stop_str.append(req.hit_stop_str)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                output_spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )

                meta_info = {
                    "prompt_tokens": req.prompt_tokens,
                    "completion_tokens": len(req.input_ids)
                    + len(req.output_ids)
                    - req.prompt_tokens,
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
                    "hit_stop_str": req.hit_stop_str,
                }
                if req.return_logprob:
                    (
                        meta_info["prefill_token_logprobs"],
                        meta_info["decode_token_logprobs"],
                        meta_info["prefill_top_logprobs"],
                        meta_info["decode_top_logprobs"],
                        meta_info["normalized_prompt_logprob"],
                    ) = (
                        req.prefill_token_logprobs,
                        req.decode_token_logprobs,
                        req.prefill_top_logprobs,
                        req.decode_top_logprobs,
                        req.normalized_prompt_logprob,
                    )
                output_meta_info.append(meta_info)
                output_finished.append(req.finished)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    output_tokens,
                    output_and_jump_forward_strs,
                    output_hit_stop_str,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
                    output_finished,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                req_pool_idx = req_pool_indices_cpu[i]
                token_ids = tuple(req.input_ids + req.output_ids)
                seq_len = len(token_ids) - 1
                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
                prefix_len = self.tree_cache.insert(
                    token_ids[:seq_len], indices.clone()
                )

                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                self.req_to_token_pool.free(req_pool_idx)
                self.tree_cache.dec_ref_counter(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []


class ModelRpcService(rpyc.Service):
    exposed_ModelRpcServer = ModelRpcServer


class ModelRpcClient:
    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
        tp_size = server_args.tp_size

        if tp_size == 1:
            # Init model
            self.model_server = ModelRpcService().exposed_ModelRpcServer(
                0, server_args, port_args
            )

            # Wrap functions
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes
                rets = executor.map(start_model_process, port_args.model_rpc_ports)
                self.remote_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.remote_services[i].ModelRpcServer(
                        i, server_args, port_args
                    )

                self.model_servers = executor.map(init_model, range(tp_size))

            # Wrap functions
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")


def _init_service(port):
    t = ThreadedServer(
        ModelRpcService(),
        port=port,
        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
    )
    t.start()


def start_model_process(port):
    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    time.sleep(1)

    repeat_count = 0
    while repeat_count < 20:
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={"allow_pickle": True, "sync_request_timeout": 1800},
            )
            break
        except ConnectionRefusedError:
            time.sleep(1)
        repeat_count += 1
    if repeat_count == 20:
        raise RuntimeError("init rpc env error!")

    assert proc.is_alive()
    return con.root, proc
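

# Rough usage sketch (illustrative only): the sglang launcher normally constructs
# these objects, and the exact ServerArgs/PortArgs constructor arguments shown
# below are assumptions, not the authoritative API.
#
#   server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-hf", tp_size=1)
#   port_args = PortArgs(...)  # ports are normally allocated by the launcher
#   client = ModelRpcClient(server_args, port_args)
#   outputs = asyncio.run(client.step(recv_reqs))  # a list of BatchTokenIDOut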