import asyncio
import logging
import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List

import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
try:
    from vllm.logger import _default_handler as vllm_default_logger
except ImportError:
    from vllm.logger import logger as vllm_default_logger

from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    BatchTokenIDOut,
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    get_exception_traceback,
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
)


logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)


class ModelRpcServer:
    def __init__(
        self,
        tp_rank: int,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        server_args, port_args = [obtain(x) for x in [server_args, port_args]]

        # Copy arguments
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
        )

        # Model-side global settings passed through to the model runner
        server_args_dict = {
            "enable_flashinfer": server_args.enable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }

        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=port_args.nccl_port,
            load_format=server_args.load_format,
            trust_remote_code=server_args.trust_remote_code,
            server_args_dict=server_args_dict,
        )
        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.max_total_num_token = self.model_runner.max_total_num_token
        self.max_num_running_seq = self.max_total_num_token // 2
        self.max_prefill_num_token = max(
            self.model_config.context_len,
            (
                self.max_total_num_token // 6
                if server_args.max_prefill_num_token is None
                else server_args.max_prefill_num_token
            ),
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)
        logger.info(
            f"Rank {self.tp_rank}: "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
        )
        if self.tp_rank == 0:
            logger.info(f"server_args: {server_args.print_mode_args()}")

        # Init cache
        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = Scheduler(
            self.schedule_heuristic,
            self.max_num_running_seq,
            self.max_prefill_num_token,
            self.max_total_num_token,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()

        # Init new token estimation
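        # new_token_ratio estimates what fraction of each running request's remaining
        # max_new_tokens will actually be generated; get_new_fill_batch uses it to
        # decide how much KV-cache memory to reserve before admitting new prefills.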
        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)

    def flush_cache(self):
        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

    def exposed_step(self, recv_reqs):
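        # With tensor parallelism, requests arrive as rpyc netrefs; obtain() copies
        # them into local objects before they are used.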
        if self.tp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelRpcServer:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run new fill batch
            self.forward_fill_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
                        break

                    if self.running_batch is not None and self.tp_rank == 0:
                        if self.decode_forward_ct % 40 == 0:
                            num_used = self.max_total_num_token - (
                                self.token_to_kv_pool.available_size()
                                + self.tree_cache.evictable_size()
                            )
                            throughput = self.num_generated_tokens / (
                                time.time() - self.last_stats_tic
                            )
                            self.num_generated_tokens = 0
                            self.last_stats_tic = time.time()
                            logger.info(
                                f"#running-req: {len(self.running_batch.reqs)}, "
                                f"#token: {num_used}, "
                                f"token usage: {num_used / self.max_total_num_token:.2f}, "
                                f"gen throughput (token/s): {throughput:.2f}, "
                                f"#queue-req: {len(self.forward_queue)}"
                            )
            else:
                # check the available size
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                if available_size != self.max_total_num_token:
                    warnings.warn(
                        "KV cache pool leak detected! "
                        f"available_size={available_size}, "
                        f"max_total_num_token={self.max_total_num_token}"
                    )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
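            # Derive image pad token ids from the image hash so that prompts with
            # different images map to different radix-cache prefixes.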
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.image_size = recv_req.image_size
            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
                )

        # Truncate long prompts
        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.input_ids),
            self.max_total_num_token - 128 - len(req.input_ids),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_num_running_seq
        ):
            return None

        for req in self.forward_queue:
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
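            # For logprob requests, reuse at most logprob_start_len cached tokens so
            # the remaining prompt tokens are recomputed and their logprobs returned.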
            if req.return_logprob:
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
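        # Reserve room for the tokens the running batch is still expected to
        # generate, discounted by the estimated new_token_ratio.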
        if self.running_batch:
            available_size -= sum(
                [
                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                    for r in self.running_batch.reqs
                ]
            )

        for req in self.forward_queue:
            if req.return_logprob:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.extend_input_len + new_batch_input_tokens
                < self.max_prefill_num_token
            ):
                delta = self.tree_cache.inc_ref_counter(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo the insertion
                    delta = self.tree_cache.dec_ref_counter(req.last_node)
                    available_size += delta
                    break
                else:
                    # Add this request to the running batch
                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.extend_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.extend_input_len
            else:
                break
        if len(can_run_list) == 0:
            return None

        if self.tp_rank == 0:
            running_req = (
                0 if self.running_batch is None else len(self.running_batch.reqs)
            )
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
            logger.info(
                f"new fill batch. #seq: {len(can_run_list)}. "
                f"#cached_token: {hit_tokens}. "
                f"#new_token: {new_batch_input_tokens}. "
                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
                f"#running_req: {running_req}. "
                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
            )
            #logger.debug(
            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
            #)

        new_batch = Batch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_token_logprobs,
                normalized_prompt_logprobs,
                prefill_top_logprobs,
                decode_top_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
            if prefill_token_logprobs is not None:
                prefill_token_logprobs = prefill_token_logprobs.tolist()
                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

            next_token_ids, _ = batch.sample(logits)

            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
            if last_logprobs is not None:
                last_token_logprobs = (
                    last_logprobs[
                        torch.arange(len(batch.reqs), device=next_token_ids.device),
                        next_token_ids,
                    ].tolist()
                )

            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

        # Check finish condition
        pt = 0
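        # pt is the offset of each request's tokens inside the flattened
        # prefill logprob tensors; each request owns extend_input_len entries.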
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids = [next_token_ids[i]]
            req.check_finished()

            if req.return_logprob:
                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                req.prefill_token_logprobs = list(
                    zip(
                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
                        req.input_ids[-req.extend_input_len + 1 :],
                    )
                )
                if req.logprob_start_len == 0:
                    req.prefill_token_logprobs = [
                        (None, req.input_ids[0])
                    ] + req.prefill_token_logprobs
                req.decode_token_logprobs = [
                    (last_token_logprobs[i], next_token_ids[i])
                ]

            if req.top_logprobs_num > 0:
                req.prefill_top_logprobs = prefill_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
                req.decode_top_logprobs = [decode_top_logprobs[i]]

            pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def forward_decode_batch(self, batch: Batch):
        # check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)

            retracted_reqs = batch.retract_decode()
            logger.info(
                "Decode out of memory, retracting requests. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_step[0],
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward()

            # check for image jump-forward
            for req in jump_forward_reqs:
                if req.pixel_values is not None:
                    (
                        req.input_ids,
                        req.image_offset,
                    ) = self.model_runner.model.pad_input_ids(
                        req.input_ids,
                        req.pad_value,
                        req.pixel_values.shape,
                        req.image_size,
                    )

            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
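        # decode_forward_ct wraps around to stay bounded; it paces stream
        # flushes and the periodic decode statistics log.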
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward
        logits, (
            _,
            _,
            _,
            decode_top_logprobs,
            last_logprobs,
        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(logits)
        next_token_ids = next_token_ids.tolist()

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        if last_logprobs is not None:
            new_token_logprobs = last_logprobs[
                torch.arange(len(batch.reqs)), next_token_ids
            ].tolist()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.return_logprob:
                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))

            if req.top_logprobs_num > 0:
                req.decode_top_logprobs.append(decode_top_logprobs[i])

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
        output_rids = []
        output_tokens = []
        output_and_jump_forward_strs = []
        output_hit_stop_str = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
        output_finished = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished:
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

            if req.finished or (
                (
                    req.stream
                    and (
                        self.decode_forward_ct % self.stream_interval == 0
                        or len(req.output_ids) == 1
                    )
                )
            ):
                output_rids.append(req.rid)
                output_tokens.append(req.output_ids)
                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                output_hit_stop_str.append(req.hit_stop_str)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                output_spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )

                meta_info = {
                    "prompt_tokens": req.prompt_tokens,
                    "completion_tokens": len(req.input_ids)
                    + len(req.output_ids)
                    - req.prompt_tokens,
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
                    "hit_stop_str": req.hit_stop_str,
                }
                if req.return_logprob:
                    (
                        meta_info["prefill_token_logprobs"],
                        meta_info["decode_token_logprobs"],
                        meta_info["prefill_top_logprobs"],
                        meta_info["decode_top_logprobs"],
                        meta_info["normalized_prompt_logprob"],
                    ) = (
                        req.prefill_token_logprobs,
                        req.decode_token_logprobs,
                        req.prefill_top_logprobs,
                        req.decode_top_logprobs,
                        req.normalized_prompt_logprob,
                    )
                output_meta_info.append(meta_info)
                output_finished.append(req.finished)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    output_tokens,
                    output_and_jump_forward_strs,
                    output_hit_stop_str,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
                    output_finished,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                req_pool_idx = req_pool_indices_cpu[i]
                token_ids = tuple(req.input_ids + req.output_ids)
                seq_len = len(token_ids) - 1
                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
                prefix_len = self.tree_cache.insert(
                    token_ids[:seq_len], indices.clone()
                )

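                # The first prefix_len tokens are now owned by the radix cache,
                # so release this request's duplicate references to their KV slots.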
                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                self.req_to_token_pool.free(req_pool_idx)
                self.tree_cache.dec_ref_counter(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []


class ModelRpcService(rpyc.Service):
    exposed_ModelRpcServer = ModelRpcServer


class ModelRpcClient:
    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
        tp_size = server_args.tp_size

        if tp_size == 1:
            # Init model
            self.model_server = ModelRpcService().exposed_ModelRpcServer(
                0, server_args, port_args
            )

            # Wrap functions
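            # The single-process model server is called directly; wrapping it in an
            # async function lets the router await it the same way as the RPC path.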
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes
                rets = executor.map(start_model_process, port_args.model_rpc_ports)
                self.remote_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.remote_services[i].ModelRpcServer(
                        i, server_args, port_args
                    )

                self.model_servers = executor.map(init_model, range(tp_size))

            # Wrap functions
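            # Broadcast each call to every tensor-parallel worker via rpyc async
            # requests and return the result from rank 0 only.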
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")


def _init_service(port):
    t = ThreadedServer(
        ModelRpcService(),
        port=port,
        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
    )
    t.start()


def start_model_process(port):
    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    time.sleep(1)

    repeat_count = 0
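    # Poll until the child process's RPC server accepts connections (up to 20 attempts).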
    while repeat_count < 20:
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={"allow_pickle": True, "sync_request_timeout": 1800},
            )
            break
        except ConnectionRefusedError:
            time.sleep(1)
        repeat_count += 1
    if repeat_count == 20:
        raise RuntimeError("Failed to connect to the model RPC server.")

    assert proc.is_alive()
    return con.root, proc