import asyncio
import logging
import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List

import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
try:
    from vllm.logger import _default_handler as vllm_default_logger
except ImportError:
    from vllm.logger import logger as vllm_default_logger

from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    BatchTokenIDOut,
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    get_exception_traceback,
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
)

logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)


class ModelRpcServer:
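    """A tensor-parallel model worker.

    Each instance owns one tp_rank shard of the model and runs the
    continuous-batching loop: it admits prefill (extend) batches, decodes the
    running batch, and emits token batches for the detokenizer. It is used
    in-process when tp_size == 1 and exposed over rpyc otherwise.
    """
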
    def __init__(
        self,
        tp_rank: int,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
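        # When constructed over rpyc, server_args and port_args arrive as netrefs;
        # obtain() materializes local copies before they are used below.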
        server_args, port_args = [obtain(x) for x in [server_args, port_args]]

        # Copy arguments
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
        )

        # Global settings forwarded to the model runner
        server_args_dict = {
            "enable_flashinfer": server_args.enable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }

        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=port_args.nccl_port,
            load_format=server_args.load_format,
            trust_remote_code=server_args.trust_remote_code,
            server_args_dict=server_args_dict,
        )
        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
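
        # Token budgets derived from the profiled KV cache capacity: the number of
        # running requests is capped at half the token capacity, and a single
        # prefill batch at max_prefill_num_token input tokens (at least the model
        # context length; defaults to 1/6 of the pool unless overridden).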
        self.max_total_num_token = self.model_runner.max_total_num_token
        self.max_num_running_seq = self.max_total_num_token // 2
        self.max_prefill_num_token = max(
            self.model_config.context_len,
            (
                self.max_total_num_token // 6
                if server_args.max_prefill_num_token is None
                else server_args.max_prefill_num_token
            ),
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)
        logger.info(
            f"Rank {self.tp_rank}: "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
        )
        if self.tp_rank == 0:
            logger.info(f"server_args: {server_args.print_mode_args()}")

        # Init cache
        self.tree_cache = RadixCache(
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            disable=server_args.disable_radix_cache,
        )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = Scheduler(
            self.schedule_heuristic,
            self.max_num_running_seq,
            self.max_prefill_num_token,
            self.max_total_num_token,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()

        # Init new token estimation
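        # new_token_ratio estimates the fraction of each running request's remaining
        # max_new_tokens that must stay reserved in the KV cache when admitting new
        # prefill requests; it is raised after a decode-time retraction and decays
        # back toward min_new_token_ratio otherwise.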
        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)

    def flush_cache(self):
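        # Flushing is only safe when nothing is queued or running; it resets the
        # radix cache, the regex FSM cache, and both memory pools.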
        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

    def exposed_step(self, recv_reqs):
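        """One scheduler tick: ingest new requests, run a forward step, and return
        the outputs accumulated for the detokenizer."""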
        if self.tp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelRpcServer:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
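        # Prefill takes priority: if a new fill (extend) batch can be formed, run it
        # and merge it into the running batch; otherwise run up to 10 consecutive
        # decode steps on the running batch to amortize scheduling overhead.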
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run new fill batch
            self.forward_fill_batch(new_batch)

            self.cache_filled_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
                        break

                    if self.running_batch is not None and self.tp_rank == 0:
                        if self.decode_forward_ct % 40 == 0:
                            num_used = self.max_total_num_token - (
                                self.token_to_kv_pool.available_size()
                                + self.tree_cache.evictable_size()
                            )
                            throughput = self.num_generated_tokens / (
                                time.time() - self.last_stats_tic
                            )
                            self.num_generated_tokens = 0
                            self.last_stats_tic = time.time()
                            logger.info(
                                f"#running-req: {len(self.running_batch.reqs)}, "
                                f"#token: {num_used}, "
                                f"token usage: {num_used / self.max_total_num_token:.2f}, "
                                f"gen throughput (token/s): {throughput:.2f}, "
                                f"#queue-req: {len(self.forward_queue)}"
                            )
            else:
                # check the available size
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                if available_size != self.max_total_num_token:
                    warnings.warn(
                        "Warning: "
                        f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                        "KV cache pool leak detected!"
                    )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
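        # Build a Req from the tokenized input, attach image padding for multimodal
        # requests, set up regex FSM / jump-forward state, clamp the prompt and
        # max_new_tokens to the context window, and queue it for prefill.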
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.image_size = recv_req.image_size
            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
                )

        # Truncate long prompts
        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.input_ids),
            self.max_total_num_token - 128 - len(req.input_ids),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
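        """Select queued requests that fit into memory and form a prefill batch.

        Requests are matched against the radix cache, ordered by the scheduling
        heuristic, and admitted while both the projected KV cache usage and the
        prefill token budget allow it.
        """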
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_num_running_seq
        ):
            return None

        for req in self.forward_queue:
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if self.running_batch:
            available_size -= sum(
                [
                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                    for r in self.running_batch.reqs
                ]
            )

        for req in self.forward_queue:
            if req.return_logprob:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.extend_input_len + new_batch_input_tokens
                < self.max_prefill_num_token
            ):
                delta = self.tree_cache.inc_lock_ref(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo locking
                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                    available_size += delta
                    break
                else:
                    # Add this request to the running batch
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.extend_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.extend_input_len
            else:
                break
        if len(can_run_list) == 0:
            return None

        if self.tp_rank == 0:
            running_req = (
                0 if self.running_batch is None else len(self.running_batch.reqs)
            )
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
            logger.info(
                f"new fill batch. #seq: {len(can_run_list)}. "
                f"#cached_token: {hit_tokens}. "
                f"#new_token: {new_batch_input_tokens}. "
                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
                f"#running_req: {running_req}. "
                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
            )
            #logger.debug(
            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
            #)

        new_batch = Batch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
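        # Run one extend (prefill) forward pass over the new batch, sample the first
        # output token of every request, and record prompt/decode logprobs where
        # they were requested.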
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_token_logprobs,
                normalized_prompt_logprobs,
                prefill_top_logprobs,
                decode_top_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
            if prefill_token_logprobs is not None:
                prefill_token_logprobs = prefill_token_logprobs.tolist()
                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

            next_token_ids, _ = batch.sample(logits)

            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
            if last_logprobs is not None:
                last_token_logprobs = (
                    last_logprobs[
                        torch.arange(len(batch.reqs), device=next_token_ids.device),
                        next_token_ids,
                    ].tolist()
                )

            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

        # Check finish condition
        pt = 0
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids = [next_token_ids[i]]
            req.check_finished()

            if req.return_logprob:
                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                req.prefill_token_logprobs = list(
                    zip(
                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
                        req.input_ids[-req.extend_input_len + 1 :],
                    )
                )
                if req.logprob_start_len == 0:
                    req.prefill_token_logprobs = [
                        (None, req.input_ids[0])
                    ] + req.prefill_token_logprobs
                req.decode_token_logprobs = [
                    (last_token_logprobs[i], next_token_ids[i])
                ]

            if req.top_logprobs_num > 0:
                req.prefill_top_logprobs = prefill_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
                req.decode_top_logprobs = [decode_top_logprobs[i]]

            pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def cache_filled_batch(self, batch: Batch):
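        # Insert each request's freshly prefilled tokens into the radix cache without
        # releasing them from the memory pool, so later requests can reuse the prefix.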
        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
                del_in_memory_pool=False,
                old_last_node=req.last_node,
            )
            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

    def forward_decode_batch(self, batch: Batch):
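        # One decode step: retract requests back to the waiting queue if KV memory is
        # tight, apply regex jump-forward rewrites, then run the model and append one
        # token to every remaining request.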
        # check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)

            retracted_reqs = batch.retract_decode()
            logger.info(
                "decode out of memory happened, "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_step[0],
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward()

            # check for image jump-forward
            for req in jump_forward_reqs:
                if req.pixel_values is not None:
                    (
                        req.input_ids,
                        req.image_offset,
                    ) = self.model_runner.model.pad_input_ids(
                        req.input_ids,
                        req.pad_value,
                        req.pixel_values.shape,
                        req.image_size,
                    )

            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward
        logits, (
            _,
            _,
            _,
            decode_top_logprobs,
            last_logprobs,
        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(logits)
        next_token_ids = next_token_ids.tolist()

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        if last_logprobs is not None:
            new_token_logprobs = last_logprobs[
                torch.arange(len(batch.reqs)), next_token_ids
            ].tolist()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.return_logprob:
                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))

            if req.top_logprobs_num > 0:
                req.decode_top_logprobs.append(decode_top_logprobs[i])

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
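        # Collect requests that finished or hit the streaming interval, emit a
        # BatchTokenIDOut for the detokenizer, then cache finished requests in the
        # radix tree and drop them from the batch.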
        output_rids = []
        output_tokens = []
        output_and_jump_forward_strs = []
        output_hit_stop_str = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
        output_finished = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished:
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

            if req.finished or (
                (
                    req.stream
                    and (
                        self.decode_forward_ct % self.stream_interval == 0
                        or len(req.output_ids) == 1
                    )
                )
            ):
                output_rids.append(req.rid)
                output_tokens.append(req.output_ids)
                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                output_hit_stop_str.append(req.hit_stop_str)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                output_spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )

                meta_info = {
                    "prompt_tokens": req.prompt_tokens,
                    "completion_tokens": len(req.input_ids)
                    + len(req.output_ids)
                    - req.prompt_tokens,
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
                    "hit_stop_str": req.hit_stop_str,
                }
                if req.return_logprob:
                    (
                        meta_info["prefill_token_logprobs"],
                        meta_info["decode_token_logprobs"],
                        meta_info["prefill_top_logprobs"],
                        meta_info["decode_top_logprobs"],
                        meta_info["normalized_prompt_logprob"],
                    ) = (
                        req.prefill_token_logprobs,
                        req.decode_token_logprobs,
                        req.prefill_top_logprobs,
                        req.decode_top_logprobs,
                        req.normalized_prompt_logprob,
                    )
                output_meta_info.append(meta_info)
                output_finished.append(req.finished)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    output_tokens,
                    output_and_jump_forward_strs,
                    output_hit_stop_str,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
                    output_finished,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
                    req_pool_idx=req_pool_indices_cpu[i],
                )

                self.tree_cache.dec_lock_ref(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []


class ModelRpcService(rpyc.Service):
    exposed_ModelRpcServer = ModelRpcServer


class ModelRpcClient:
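    """Client-side driver for one or more ModelRpcServer workers.

    With tp_size == 1 the server runs in-process and `step` is a thin async
    wrapper; otherwise one rpyc server process is launched per tensor-parallel
    rank and `step` broadcasts each call to all ranks, returning rank 0's result.
    """
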
    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
        tp_size = server_args.tp_size

        if tp_size == 1:
            # Init model
            self.model_server = ModelRpcService().exposed_ModelRpcServer(
                0, server_args, port_args
            )

            # Wrap functions
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes
                rets = executor.map(start_model_process, port_args.model_rpc_ports)
                self.remote_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.remote_services[i].ModelRpcServer(
                        i, server_args, port_args
                    )

                self.model_servers = executor.map(init_model, range(tp_size))

            # Wrap functions
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")


def _init_service(port):
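    # Serve ModelRpcService on the given port. This call blocks, so it is run in a
    # child process by start_model_process.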
    t = ThreadedServer(
        ModelRpcService(),
        port=port,
        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
    )
    t.start()


def start_model_process(port):
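    # Launch _init_service in a subprocess and retry connecting for up to ~20
    # seconds before giving up.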
    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    time.sleep(1)

    repeat_count = 0
    while repeat_count < 20:
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={"allow_pickle": True, "sync_request_timeout": 1800},
            )
            break
        except ConnectionRefusedError:
            time.sleep(1)
        repeat_count += 1
    if repeat_count == 20:
        raise RuntimeError("Failed to connect to the model RPC service.")

    assert proc.is_alive()
    return con.root, proc
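

# Illustrative usage sketch (not part of this module; the construction of
# ServerArgs/PortArgs is assumed to happen in the router process):
#
#     client = ModelRpcClient(server_args, port_args)
#     out_pyobjs = await client.step(recv_reqs)
#
# where recv_reqs is a list of TokenizedGenerateReqInput / FlushCacheReq objects
# and out_pyobjs is a list of BatchTokenIDOut batches destined for the detokenizer.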