# model_rpc.py — model worker RPC service (ModelRpcServer / ModelRpcClient).
import asyncio
import logging
import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List

import numpy as np
import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
from vllm.logger import _default_handler as vllm_default_handler

from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    BatchTokenIDOut,
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    get_exception_traceback,
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
)

# Module logger; a fixed name (rather than __name__) keeps log config keyed on "model_rpc".
logger = logging.getLogger(name="model_rpc")


class ModelRpcServer(rpyc.Service):
    """Model worker for one tensor-parallel rank, exposed as an rpyc service.

    Holds the model runner, tokenizer, radix prefix cache, scheduler, and the
    waiting/running request state for this rank. Methods prefixed with
    ``exposed_`` are the entry points ``ModelRpcClient`` calls (directly when
    ``tp_size == 1``, through rpyc otherwise).
    """

    def exposed_init_model(
        self,
        tp_rank: int,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Build the model, tokenizer, caches, and scheduler for this rank.

        Args:
            tp_rank: Tensor-parallel rank of this worker.
            server_args: Server configuration; copied locally with ``obtain``
                because it may arrive as an rpyc reference.
            port_args: Port configuration; copied locally the same way.
        """
        server_args, port_args = [obtain(x) for x in [server_args, port_args]]

        # Copy arguments
        self.model_mode = server_args.model_mode
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
        vllm_default_handler.setLevel(
            level=getattr(logging, server_args.log_level.upper())
        )

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path, server_args.trust_remote_code
        )
        self.model_runner = ModelRunner(
            self.model_config,
            server_args.mem_fraction_static,
            tp_rank,
            server_args.tp_size,
            port_args.nccl_port,
            server_args.load_format,
            server_args.trust_remote_code,
            server_args.model_mode,
        )
        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.eos_token_id = self.tokenizer.eos_token_id
        self.max_total_num_token = self.model_runner.max_total_num_token
        self.max_num_running_seq = self.max_total_num_token // 2
        # Prefill budget: the explicit setting, or 1/6 of the KV pool by
        # default, but never less than one full context window.
        self.max_prefill_num_token = max(
            self.model_config.context_len,
            (
                self.max_total_num_token // 6
                if server_args.max_prefill_num_token is None
                else server_args.max_prefill_num_token
            ),
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)
        logger.info(
            f"Rank {self.tp_rank}: "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
            f"model_mode={self.model_mode}"
        )

        # Init cache (the radix cache is disabled in "no-cache" model mode)
        self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = Scheduler(
            self.schedule_heuristic,
            self.max_num_running_seq,
            self.max_prefill_num_token,
            self.max_total_num_token,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()

        # Init new token estimation: the fraction of each running request's
        # remaining max_new_tokens reserved when admitting new requests.
        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)

    def flush_cache(self):
        """Reset all caches and memory pools if no request is queued or running.

        When requests are pending, nothing is reset and a warning is emitted.
        """
        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

    def exposed_step(self, recv_reqs):
        """Ingest new requests, run one scheduling step, and drain outputs.

        Args:
            recv_reqs: Iterable of ``TokenizedGenerateReqInput`` /
                ``FlushCacheReq`` objects (rpyc references when tp_size > 1,
                hence the ``obtain`` copy).

        Returns:
            The list of output objects (``BatchTokenIDOut``) produced since
            the previous call.

        Exceptions raised while handling requests or running the forward step
        are logged rather than propagated, so the step still returns whatever
        output was accumulated.
        """
        if self.tp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
        """Run either a prefill batch or up to 10 decode iterations.

        A new prefill batch takes priority; it is merged into the running
        batch afterwards. Otherwise the running batch is decoded repeatedly,
        stopping early when it empties or when streaming output is ready.
        With nothing running, the KV pool is checked for leaked slots.
        """
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run new fill batch
            self.forward_fill_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.forward_decode_batch(self.running_batch)

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    # Return early so streamed tokens reach the client promptly.
                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
                        break
            else:
                # Idle: every KV slot should be either free or evictable.
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                if available_size != self.max_total_num_token:
                    warnings.warn(
                        "Warning: "
                        f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                        "KV cache pool leak detected!"
                    )

        if self.running_batch is not None and self.tp_rank == 0:
            if self.decode_forward_ct % 20 == 0:
                num_used = self.max_total_num_token - (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                logger.info(
                    f"#running-req: {len(self.running_batch.reqs)}, "
                    f"#token: {num_used}, "
                    f"token usage: {num_used / self.max_total_num_token:.2f}, "
                    f"#queue-req: {len(self.forward_queue)}"
                )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Convert a tokenized generate request into a ``Req`` and enqueue it.

        Pads image inputs, attaches regex FSM / jump-forward state, truncates
        the prompt to the context window, and clamps ``max_new_tokens`` to fit
        both the context window and the KV pool.
        """
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
            # NOTE(review): if image_hash fits in 64 bits, ``>> 64`` leaves only
            # its sign (0 or -1); confirm the intended shift.
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.image_size = recv_req.image_size
            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
                )

        # Truncate long prompts
        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.input_ids),
            self.max_total_num_token - 128 - len(req.input_ids),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
        """Pick queued requests that fit in memory and build a prefill batch.

        Returns:
            A new ``Batch`` to prefill, or ``None`` when the running batch is
            already over ``max_num_running_seq`` or no queued request fits.
        """
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_num_running_seq
        ):
            return None

        for req in self.forward_queue:
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
                # Tokens from logprob_start_len on must be re-extended so their
                # logprobs get computed; cap the reused prefix accordingly.
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if self.running_batch:
            # Reserve an estimate of the tokens the running requests will still
            # decode, scaled by the adaptive new_token_ratio.
            available_size -= sum(
                (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                for r in self.running_batch.reqs
            )

        for req in self.forward_queue:
            if req.return_logprob:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.extend_input_len + new_batch_input_tokens
                < self.max_prefill_num_token
            ):
                # Locking the matched prefix node returns a size delta that is
                # applied to available_size (presumably the change in
                # evictable tokens); re-check the budget afterwards.
                delta = self.tree_cache.inc_ref_counter(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo the insertion
                    delta = self.tree_cache.dec_ref_counter(req.last_node)
                    available_size += delta
                else:
                    # Add this request to the running batch
                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.extend_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.extend_input_len

        if len(can_run_list) == 0:
            return None

        if self.tp_rank == 0:
            running_req = (
                0 if self.running_batch is None else len(self.running_batch.reqs)
            )
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            # Counters are accumulated in units of 1e9 tokens.
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            total = self.tree_cache_metrics["total"]
            # A zero total (e.g. only empty prompts since the last flush) would
            # raise ZeroDivisionError here, after ref counts were already taken
            # above and before the batch that releases them is built.
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / total if total > 0 else 0.0
            )
            logger.info(
                f"new fill batch. #seq: {len(can_run_list)}. "
                f"#cached_token: {hit_tokens}. "
                f"#new_token: {new_batch_input_tokens}. "
                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
                f"#running_req: {running_req}. "
                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
            )
            logger.debug(
                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
                f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
                f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
            )

        new_batch = Batch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
        """Run the prefill (extend) forward pass and sample the first token.

        Also records prompt logprobs when the batch requests them, then hands
        the batch to ``handle_finished_requests``.
        """
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        logprobs = None
        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_logprobs,
                normalized_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(
                batch, ForwardMode.EXTEND, batch.return_logprob
            )
            if prefill_logprobs is not None:
                logprobs = prefill_logprobs.cpu().tolist()
                normalized_logprobs = normalized_logprobs.cpu().tolist()

            next_token_ids, _ = batch.sample(logits)
            next_token_ids = next_token_ids.cpu().tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
            logits = logprobs = normalized_logprobs = last_logprobs = None

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        reqs = batch.reqs
        if last_logprobs is not None:
            last_logprobs = (
                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
            )

        # Check finish condition
        pt = 0
        for i, req in enumerate(reqs):
            req.output_ids = [next_token_ids[i]]
            req.check_finished()

            if logprobs is not None:
                # Each request occupies extend_input_len entries of the flat
                # prefill logprobs; the first extend_input_len - 1 are kept.
                req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
                req.normalized_logprob = normalized_logprobs[i]

                # The logprobs cover only the tokens extended in this pass, so
                # pair them with those tokens alone. Zipping the full
                # input_ids misaligned every pair whenever a cached prefix was
                # reused. The first extended token has no logprob (None).
                extend_start = max(0, len(req.input_ids) - req.extend_input_len)
                token_ids = req.input_ids[extend_start:] + [next_token_ids[i]]
                token_logprobs = [None] + req.logprob + [last_logprobs[i]]
                req.token_logprob = list(zip(token_ids, token_logprobs))
                pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def forward_decode_batch(self, batch: Batch):
        """Run one decode step for ``batch``.

        If decode memory is insufficient, requests are retracted to the queue
        and ``new_token_ratio`` is raised; otherwise the ratio decays toward
        ``min_new_token_ratio``. Requests that jump-forward (regex) are moved
        back to the queue before decoding.
        """
        # check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)

            retracted_reqs = batch.retract_decode()
            logger.info(
                "decode out of memory happened, "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_step[0],
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward()

            # check for image jump-forward
            for req in jump_forward_reqs:
                if req.pixel_values is not None:
                    (
                        req.input_ids,
                        req.image_offset,
                    ) = self.model_runner.model.pad_input_ids(
                        req.input_ids,
                        req.pad_value,
                        req.pixel_values.shape,
                        req.image_size,
                    )

            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors (counter wraps to stay bounded)
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward
        logits, (_, _, last_logprobs) = self.model_runner.forward(
            batch,
            ForwardMode.DECODE,
            batch.return_logprob,
        )
        next_token_ids, _ = batch.sample(logits)
        next_token_ids = next_token_ids.cpu().tolist()

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        reqs = batch.reqs
        if last_logprobs is not None:
            last_logprobs = (
                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
            )

        # Check finish condition
        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
            req.output_ids.append(next_tok_id)
            req.check_finished()

            if last_logprobs is not None:
                req.token_logprob.append((next_tok_id, last_logprobs[i]))

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
        """Emit outputs for finished/streaming requests and retire finished ones.

        Finished requests are inserted into the radix cache, their duplicate
        KV slots and request slot are freed, and the batch is filtered down to
        the unfinished requests.
        """
        output_rids = []
        output_tokens = []
        output_and_jump_forward_strs = []
        output_hit_stop_str = []
        output_skip_special_tokens = []
        output_meta_info = []
        output_finished = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished:
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

            if req.finished or (
                (
                    req.stream
                    and (
                        self.decode_forward_ct % self.stream_interval == 0
                        or len(req.output_ids) == 1
                    )
                )
            ):
                output_rids.append(req.rid)
                output_tokens.append(req.output_ids)
                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                output_hit_stop_str.append(req.hit_stop_str)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )

                # input_ids grows during jump-forward, so token usage is
                # computed relative to the original prompt length.
                meta_info = {
                    "prompt_tokens": req.orig_prompt_tokens,
                    "completion_tokens": len(req.input_ids)
                    + len(req.output_ids)
                    - req.orig_prompt_tokens,
                }

                if req.return_logprob:
                    meta_info["prompt_logprob"] = req.logprob
                    meta_info["token_logprob"] = req.token_logprob
                    meta_info["normalized_prompt_logprob"] = req.normalized_logprob
                output_meta_info.append(meta_info)
                output_finished.append(req.finished)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    output_tokens,
                    output_and_jump_forward_strs,
                    output_hit_stop_str,
                    output_skip_special_tokens,
                    output_meta_info,
                    output_finished,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                req_pool_idx = req_pool_indices_cpu[i]
                token_ids = tuple(req.input_ids + req.output_ids)
                # The last sampled token has no KV entry yet, hence len - 1.
                seq_len = len(token_ids) - 1
                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
                prefix_len = self.tree_cache.insert(
                    token_ids[:seq_len], indices.clone()
                )

                # The first prefix_len slots were already held by the cache;
                # free this request's copies of them.
                self.token_to_kv_pool.free(indices[:prefix_len])
                self.req_to_token_pool.free(req_pool_idx)
                self.tree_cache.dec_ref_counter(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []


class ModelRpcClient:
    """Front-end handle to one or more ``ModelRpcServer`` workers.

    With ``tp_size == 1`` the server runs in-process. Otherwise one server
    process is launched per port in ``port_args.model_rpc_ports`` and driven
    over rpyc. In both cases ``self.step`` is an async callable that returns
    the worker output (rank 0's output when tensor-parallel).
    """

    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
        tp_size = server_args.tp_size

        if tp_size == 1:
            # Init model
            self.model_server = ModelRpcServer()
            self.model_server.exposed_init_model(0, server_args, port_args)

            # Wrap functions
            def async_wrap(f):
                # NOTE: the wrapped call runs synchronously inside the coroutine.
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes. Executor.map returns a one-shot
                # iterator; materialize it because it is read twice below
                # (reading the bare iterator twice left self.procs empty).
                rets = list(
                    executor.map(start_model_process, port_args.model_rpc_ports)
                )
                self.model_servers = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.model_servers[i].init_model(i, server_args, port_args)

                # Consuming the results waits for every rank to finish and
                # re-raises any exception from a worker's init.
                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]

            # Wrap functions
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    # Fan out to every rank, wait for all, return rank 0's value.
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")


def _init_service(port):
    """Run an rpyc ``ThreadedServer`` hosting a new ``ModelRpcServer`` on ``port``."""
    service = ModelRpcServer()
    protocol = {"allow_pickle": True, "sync_request_timeout": 1800}
    server = ThreadedServer(service, port=port, protocol_config=protocol)
    server.start()

def start_model_process(port):
    """Spawn a model RPC server process on ``port`` and connect to it.

    Returns:
        A ``(root, proc)`` tuple: the root object of the rpyc connection and
        the ``multiprocessing.Process`` hosting the server.

    Raises:
        RuntimeError: if no connection succeeds within 20 attempts (the child
            process is terminated first), or if the child process is no
            longer alive after connecting.
    """
    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    # Give the child a moment to start listening before the first attempt.
    time.sleep(1)

    for _ in range(20):
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={"allow_pickle": True, "sync_request_timeout": 1800},
            )
            break
        except ConnectionRefusedError:
            time.sleep(1)
    else:
        # Do not leave an orphaned server process behind on failure.
        proc.terminate()
        proc.join(timeout=5)
        raise RuntimeError("init rpc env error!")

    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if not proc.is_alive():
        raise RuntimeError("model rpc process is not alive after connecting")
    return con.root, proc