import asyncio
import logging
import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Union

import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer

try:
    from vllm.logger import _default_handler as vllm_default_logger
except ImportError:
    from vllm.logger import logger as vllm_default_logger

from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchTokenIDOut,
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
logging.getLogger("vllm.selector").setLevel(logging.WARN)


class ModelRpcServer:
    def __init__(
        self,
        tp_rank: int,
        server_args: ServerArgs,
        port_args: PortArgs,
        model_overide_args: Optional[dict] = None,
    ):
        server_args, port_args = [obtain(x) for x in [server_args, port_args]]

        # Copy arguments
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_overide_args=model_overide_args,
        )

        # Model-side global settings forwarded to the model runner
        server_args_dict = {
            "enable_flashinfer": server_args.enable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }

        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=port_args.nccl_port,
            load_format=server_args.load_format,
            trust_remote_code=server_args.trust_remote_code,
            server_args_dict=server_args_dict,
        )
        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
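        # Token budgets: at most half of the KV pool may be claimed by
        # concurrently running sequences, and a prefill batch defaults to one
        # sixth of the pool, but never less than the model's context length.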
        self.max_total_num_token = self.model_runner.max_total_num_token
        self.max_num_running_seq = self.max_total_num_token // 2
        self.max_prefill_num_token = max(
            self.model_config.context_len,
            (
                self.max_total_num_token // 6
                if server_args.max_prefill_num_token is None
                else server_args.max_prefill_num_token
            ),
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)

        # Print info
        logger.info(
            f"Rank {self.tp_rank}: "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
        )
        if self.tp_rank == 0:
            logger.info(f"server_args: {server_args.print_mode_args()}")

        # Init cache
        self.tree_cache = RadixCache(
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            disable=server_args.disable_radix_cache,
        )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = Scheduler(
            self.schedule_heuristic,
            self.max_num_running_seq,
            self.max_prefill_num_token,
            self.max_total_num_token,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()
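        # Both caches are keyed by the regex string, so repeated requests with
        # the same sampling_params.regex reuse the compiled FSM and
        # jump-forward map instead of rebuilding them.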

        # Init new token estimation
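        # new_token_ratio estimates the fraction of each running request's
        # remaining max_new_tokens budget that must stay reserved in the KV
        # cache when admitting prefills; it is raised on decode OOM (see
        # forward_decode_batch) and decays back down otherwise.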
        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)

    def exposed_step(self, recv_reqs):
        if self.tp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                elif isinstance(recv_req, AbortReq):
                    self.abort_request(recv_req)
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run a new fill batch
            self.forward_fill_batch(new_batch)
            self.cache_filled_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

                    # Print stats
                    if self.tp_rank == 0:
                        if self.decode_forward_ct % 40 == 0:
                            num_used = self.max_total_num_token - (
                                self.token_to_kv_pool.available_size()
                                + self.tree_cache.evictable_size()
                            )
                            throughput = self.num_generated_tokens / (
                                time.time() - self.last_stats_tic
                            )
                            self.num_generated_tokens = 0
                            self.last_stats_tic = time.time()
                            logger.info(
                                f"#running-req: {len(self.running_batch.reqs)}, "
                                f"#token: {num_used}, "
                                f"token usage: {num_used / self.max_total_num_token:.2f}, "
                                f"gen throughput (token/s): {throughput:.2f}, "
                                f"#queue-req: {len(self.forward_queue)}"
                            )

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
                        break
            else:
                # Check the available size
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                if available_size != self.max_total_num_token:
                    warnings.warn(
                        "KV cache pool leak detected! "
                        f"available_size={available_size}, max_total_num_token={self.max_total_num_token}"
                    )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
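            # Derive pad token ids from slices of the image hash so that
            # prompts with different images get distinct token prefixes and do
            # not falsely share radix-cache entries.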
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.image_size = recv_req.image_size
            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
                )

        # Truncate prompts that are too long
        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.input_ids),
            self.max_total_num_token - 128 - len(req.input_ids),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_num_running_seq
        ):
            return None

        # Compute matched prefix length
        for req in self.forward_queue:
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if self.running_batch:
            available_size -= sum(
                [
                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                    for r in self.running_batch.reqs
                ]
            )

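        # Greedily admit queued requests while both the projected KV usage
        # (prompt tokens plus estimated new tokens per request) and the
        # per-batch prefill token budget stay within limits.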
        for req in self.forward_queue:
            if req.return_logprob:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.extend_input_len + new_batch_input_tokens
                < self.max_prefill_num_token
            ):
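                # Lock the matched prefix so it cannot be evicted while this
                # request runs; inc_lock_ref returns the signed change in
                # evictable size, which is folded back into the budget.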
                delta = self.tree_cache.inc_lock_ref(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo locking
                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                    available_size += delta
                    break
                else:
                    # Add this request to the running batch
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.extend_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.extend_input_len
            else:
                break
        if len(can_run_list) == 0:
            return None

        # Print stats
        if self.tp_rank == 0:
            running_req = (
                0 if self.running_batch is None else len(self.running_batch.reqs)
            )
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
            logger.info(
                f"new fill batch. #seq: {len(can_run_list)}. "
                f"#cached_token: {hit_tokens}. "
                f"#new_token: {new_batch_input_tokens}. "
                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
                f"#running_req: {running_req}. "
                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
            )
            # logger.debug(
            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
            # )

        # Return the new batch
        new_batch = Batch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_token_logprobs,
                normalized_prompt_logprobs,
                prefill_top_logprobs,
                decode_top_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
            if prefill_token_logprobs is not None:
                prefill_token_logprobs = prefill_token_logprobs.tolist()
                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

            next_token_ids, _ = batch.sample(logits)

            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
            if last_logprobs is not None:
                last_token_logprobs = last_logprobs[
                    torch.arange(len(batch.reqs), device=next_token_ids.device),
                    next_token_ids,
                ].tolist()

            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

        # Check finish condition
        pt = 0
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids = [next_token_ids[i]]
            req.check_finished()

            if req.return_logprob:
                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                req.prefill_token_logprobs = list(
                    zip(
                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
                        req.input_ids[-req.extend_input_len + 1 :],
                    )
                )
                if req.logprob_start_len == 0:
                    req.prefill_token_logprobs = [
                        (None, req.input_ids[0])
                    ] + req.prefill_token_logprobs
                req.decode_token_logprobs = [
                    (last_token_logprobs[i], next_token_ids[i])
                ]

            if req.top_logprobs_num > 0:
                req.prefill_top_logprobs = prefill_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
                req.decode_top_logprobs = [decode_top_logprobs[i]]

            pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def cache_filled_batch(self, batch: Batch):
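        # After a prefill finishes, insert each request's tokens into the
        # radix cache (del_in_memory_pool=False keeps the KV entries alive)
        # and move the request's cached prefix/lock to the new node.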
        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
                del_in_memory_pool=False,
                old_last_node=req.last_node,
            )
            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

    def forward_decode_batch(self, batch: Batch):
        # Check whether the decode batch still fits in the KV cache
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)

            retracted_reqs = batch.retract_decode()
            logger.info(
                "Decode out of memory. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_step[0],
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # Check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward()

            # Re-apply image padding for requests whose input_ids changed
            for req in jump_forward_reqs:
                if req.pixel_values is not None:
                    (
                        req.input_ids,
                        req.image_offset,
                    ) = self.model_runner.model.pad_input_ids(
                        req.input_ids,
                        req.pad_value,
                        req.pixel_values.shape,
                        req.image_size,
                    )

            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward
        logits, (
            _,
            _,
            _,
            decode_top_logprobs,
            last_logprobs,
        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(logits)
        next_token_ids = next_token_ids.tolist()

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        if last_logprobs is not None:
            new_token_logprobs = last_logprobs[
                torch.arange(len(batch.reqs)), next_token_ids
            ].tolist()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.return_logprob:
                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))

            if req.top_logprobs_num > 0:
                req.decode_top_logprobs.append(decode_top_logprobs[i])

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
        output_rids = []
        output_tokens = []
        output_and_jump_forward_strs = []
        output_hit_stop_str = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
        output_finished = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished:
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

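            # Flush a request's tokens either when it finishes or, for
            # streaming requests, every stream_interval decode steps (and
            # right after the first generated token).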
            if req.finished or (
                (
                    req.stream
                    and (
                        self.decode_forward_ct % self.stream_interval == 0
                        or len(req.output_ids) == 1
                    )
                )
            ):
                output_rids.append(req.rid)
                output_tokens.append(req.output_ids)
                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                output_hit_stop_str.append(req.hit_stop_str)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                output_spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )

                meta_info = {
                    "prompt_tokens": req.prompt_tokens,
                    "completion_tokens": len(req.input_ids)
                    + len(req.output_ids)
                    - req.prompt_tokens,
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": FinishReason.to_str(req.finish_reason),
                    "hit_stop_str": req.hit_stop_str,
                }
                if req.return_logprob:
                    (
                        meta_info["prefill_token_logprobs"],
                        meta_info["decode_token_logprobs"],
                        meta_info["prefill_top_logprobs"],
                        meta_info["decode_top_logprobs"],
                        meta_info["normalized_prompt_logprob"],
                    ) = (
                        req.prefill_token_logprobs,
                        req.decode_token_logprobs,
                        req.prefill_top_logprobs,
                        req.decode_top_logprobs,
                        req.normalized_prompt_logprob,
                    )
                output_meta_info.append(meta_info)
                output_finished.append(req.finished)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    output_tokens,
                    output_and_jump_forward_strs,
                    output_hit_stop_str,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
                    output_finished,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
                    req_pool_idx=req_pool_indices_cpu[i],
                )

                self.tree_cache.dec_lock_ref(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []

    def flush_cache(self):
        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

    def abort_request(self, recv_req):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.forward_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.forward_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid:
                    req.finished = True
                    req.finish_reason = FinishReason.ABORT
                    break


class ModelRpcService(rpyc.Service):
    exposed_ModelRpcServer = ModelRpcServer


class ModelRpcClient:
    def __init__(
        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
    ):
        tp_size = server_args.tp_size

        if tp_size == 1:
            # Init model
            self.model_server = ModelRpcService().exposed_ModelRpcServer(
                0, server_args, port_args, model_overide_args
            )

            # Wrap functions
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes
                rets = executor.map(start_model_process, port_args.model_rpc_ports)
                self.remote_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.remote_services[i].ModelRpcServer(
                        i, server_args, port_args, model_overide_args
                    )

                # Materialize the map results so any remote init error
                # surfaces here, while the executor is still alive.
                self.model_servers = list(executor.map(init_model, range(tp_size)))

            # Wrap functions
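            # Fan each call out to every tensor-parallel rank; the ranks run
            # the same scheduling logic in lockstep, so rank 0's result is
            # returned on their behalf.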
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")


def _init_service(port):
    t = ThreadedServer(
        ModelRpcService(),
        port=port,
        protocol_config={
            "allow_public_attrs": True,
            "allow_pickle": True,
            "sync_request_timeout": 1800,
        },
    )
    t.start()


def start_model_process(port):
    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    time.sleep(1)

    repeat_count = 0
    while repeat_count < 20:
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={
                    "allow_public_attrs": True,
                    "allow_pickle": True,
                    "sync_request_timeout": 1800,
                },
            )
            break
        except ConnectionRefusedError:
            time.sleep(1)
        repeat_count += 1
    if repeat_count == 20:
        raise RuntimeError("Failed to connect to the model RPC service.")

    assert proc.is_alive()
    return con.root, proc
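

# A minimal usage sketch (illustrative, not part of the original module),
# assuming a free RPC port and `server_args`/`port_args` built from
# sglang.srt.server_args. rpyc strips the `exposed_` prefix, which is why the
# remote objects are reached as `.ModelRpcServer` and `.step`:
#
#     service, proc = start_model_process(port=18000)  # hypothetical port
#     model = service.ModelRpcServer(0, server_args, port_args)
#     out_pyobjs = model.step([])  # one scheduler step with no new requests
#     proc.terminate()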