"""A tensor parallel worker."""

import asyncio
import logging
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List

import rpyc
import torch
from rpyc.utils.classic import obtain

from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchTokenIDOut,
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
    start_rpyc_process,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.tp_worker")


class ModelTpServer:
    def __init__(
        self,
        gpu_id: int,
        tp_rank: int,
        server_args: ServerArgs,
        model_port_args: ModelPortArgs,
        model_overide_args,
    ):
        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
        suppress_other_loggers()

        # Copy arguments
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.dp_size = server_args.dp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_overide_args=model_overide_args,
        )
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=model_port_args.nccl_port,
            server_args=server_args,
        )

        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        self.max_prefill_tokens = max(
            self.model_config.context_len,
            (
                min(self.max_total_num_tokens // 6, 65536)
                if server_args.max_prefill_tokens is None
                else server_args.max_prefill_tokens
            ),
        )
        self.max_running_requests = (
            self.max_total_num_tokens // 2
            if server_args.max_running_requests is None
            else server_args.max_running_requests
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)

        # Print info
        logger.info(
            f"[gpu_id={self.gpu_id}] "
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"context_len={self.model_config.context_len}, "
        )
        if self.tp_rank == 0:
            logger.info(
                f"[gpu_id={self.gpu_id}] "
                f"server_args: {server_args.print_mode_args()}"
            )

        # Init cache
        self.tree_cache = RadixCache(
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            disable=server_args.disable_radix_cache,
        )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = ScheduleHeuristic(
            self.schedule_heuristic,
            self.max_running_requests,
            self.max_prefill_tokens,
            self.max_total_num_tokens,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()

        # Init new token estimation
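        # new_token_ratio estimates, per running request, what fraction of its
        # remaining max_new_tokens budget will actually be generated; the prefill
        # scheduler uses it to over-commit KV-cache memory. It is raised toward 1.0
        # when decoding runs out of memory and decays back toward the minimum
        # otherwise (see forward_decode_batch).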
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.new_token_ratio = min(
            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            global_config.base_min_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.new_token_ratio_decay = global_config.new_token_ratio_decay
        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

    def exposed_step(self, recv_reqs):
        if self.tp_size * self.dp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                elif isinstance(recv_req, AbortReq):
                    self.abort_request(recv_req)
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
            raise

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
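        # One scheduling step: try to form a new prefill (extend) batch first;
        # otherwise run up to ten consecutive decode iterations on the running batch.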
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run a new fill batch
            self.forward_fill_batch(new_batch)
            self.cache_filled_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

                    # Print stats
                    if self.tp_rank == 0:
                        if self.decode_forward_ct % 40 == 0:
                            num_used = self.max_total_num_tokens - (
                                self.token_to_kv_pool.available_size()
                                + self.tree_cache.evictable_size()
                            )
                            throughput = self.num_generated_tokens / (
                                time.time() - self.last_stats_tic
                            )
                            self.num_generated_tokens = 0
                            self.last_stats_tic = time.time()
                            logger.info(
                                f"[gpu_id={self.gpu_id}] Decode batch. "
                                f"#running-req: {len(self.running_batch.reqs)}, "
                                f"#token: {num_used}, "
                                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                                f"gen throughput (token/s): {throughput:.2f}, "
                                f"#queue-req: {len(self.forward_queue)}"
                            )

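                    # Stop decoding early if the batch has emptied out, or if there is
                    # pending streaming output to flush back to the detokenizer.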
                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
                        break
            else:
                # Check the available size
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                if available_size != self.max_total_num_tokens:
                    warnings.warn(
                        "Warning: "
                        f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                        "KV cache pool leak detected!"
                    )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
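            # The pad token ids are derived from the image hash, so prompts that
            # contain different images map to different token sequences (and thus
            # different radix-cache prefixes).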
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.image_size = recv_req.image_size
            req.origin_input_ids, req.image_offset = (
                self.model_runner.model.pad_input_ids(
                    req.origin_input_ids_unpadded,
                    req.pad_value,
                    req.pixel_values.shape,
                    req.image_size,
                )
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
                )

        # Truncate prompts that are too long
        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.origin_input_ids),
            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_running_requests
        ):
            return None

        # Compute matched prefix length
        for req in self.forward_queue:
            assert (
                len(req.output_ids) == 0
            ), "The output ids should be empty when prefilling"
            req.input_ids = req.origin_input_ids + req.prev_output_ids
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

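        # Budget = free KV-cache slots + evictable radix-cache tokens, minus the
        # tokens the running batch is still expected to generate (discounted by
        # new_token_ratio).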
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if self.running_batch:
            available_size -= sum(
                [
                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                    for r in self.running_batch.reqs
                ]
            )

        for req in self.forward_queue:
            if req.return_logprob and req.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.extend_input_len + new_batch_input_tokens
                < self.max_prefill_tokens
            ):
                delta = self.tree_cache.inc_lock_ref(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo locking
                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                    available_size += delta
                    break
                else:
                    # Add this request to the running batch
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.extend_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.extend_input_len
            else:
                break
        if len(can_run_list) == 0:
            return None

        # Print stats
        if self.tp_rank == 0:
            running_req = (
                0 if self.running_batch is None else len(self.running_batch.reqs)
            )
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
            logger.info(
                f"[gpu_id={self.gpu_id}] Prefill batch. "
                f"#new-seq: {len(can_run_list)}, "
                f"#new-token: {new_batch_input_tokens}, "
                f"#cached-token: {hit_tokens}, "
                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                f"#running-req: {running_req}, "
                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
            )
            # logger.debug(
            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
            # )

        # Return the new batch
        new_batch = Batch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_token_logprobs,
                normalized_prompt_logprobs,
                prefill_top_logprobs,
                decode_top_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
            if prefill_token_logprobs is not None:
                prefill_token_logprobs = prefill_token_logprobs.tolist()
                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

            next_token_ids, _ = batch.sample(logits)

            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
            if last_logprobs is not None:
                last_token_logprobs = last_logprobs[
                    torch.arange(len(batch.reqs), device=next_token_ids.device),
                    next_token_ids,
                ].tolist()

            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

        # Check finish condition
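        # pt tracks this request's offset into the flattened per-token prefill
        # logprob arrays returned for the whole batch.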
        pt = 0
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids = [next_token_ids[i]]
            req.check_finished()

            if req.return_logprob:
                if req.normalized_prompt_logprob is None:
                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

                if req.prefill_token_logprobs is None:
                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                    req.prefill_token_logprobs = list(
                        zip(
                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
                            req.input_ids[-req.extend_input_len + 1 :],
                        )
                    )
                    if req.logprob_start_len == 0:
                        req.prefill_token_logprobs = [
                            (None, req.input_ids[0])
                        ] + req.prefill_token_logprobs

                if req.last_update_decode_tokens != 0:
                    req.decode_token_logprobs.extend(
                        list(
                            zip(
                                prefill_token_logprobs[
                                    pt
                                    + req.extend_input_len
                                    - req.last_update_decode_tokens : pt
                                    + req.extend_input_len
                                    - 1
                                ],
                                req.input_ids[-req.last_update_decode_tokens + 1 :],
                            )
                        )
                    )

                req.decode_token_logprobs.append(
                    (last_token_logprobs[i], next_token_ids[i])
                )

            if req.top_logprobs_num > 0:
                if req.prefill_top_logprobs is None:
                    req.prefill_top_logprobs = prefill_top_logprobs[i]
                    if req.logprob_start_len == 0:
                        req.prefill_top_logprobs = [None] + req.prefill_top_logprobs

                if req.last_update_decode_tokens != 0:
                    req.decode_top_logprobs.extend(
                        prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                    )
                req.decode_top_logprobs.append(decode_top_logprobs[i])

            pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def cache_filled_batch(self, batch: Batch):
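        # Insert each request's prefilled tokens into the radix cache so that later
        # requests can reuse them; del_in_memory_pool=False keeps the KV entries
        # alive because these requests are still running.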
        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
                del_in_memory_pool=False,
                old_last_node=req.last_node,
            )
            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

    def forward_decode_batch(self, batch: Batch):
        # Check if decode is out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)

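            # Not enough KV memory for another decode step: retract some requests
            # back to the waiting queue so their KV memory can be reused, and admit
            # new prefills more conservatively.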
            retracted_reqs = batch.retract_decode()
            logger.info(
                "decode out of memory happened, "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)

            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward
        logits, (
            _,
            _,
            _,
            decode_top_logprobs,
            last_logprobs,
        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(logits)
        next_token_ids = next_token_ids.tolist()

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        if last_logprobs is not None:
            new_token_logprobs = last_logprobs[
                torch.arange(len(batch.reqs)), next_token_ids
            ].tolist()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.return_logprob:
                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))

            if req.top_logprobs_num > 0:
                req.decode_top_logprobs.append(decode_top_logprobs[i])

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
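        # Collect outputs for requests that have finished or are due for a streaming
        # update, send them to the detokenizer, and remove finished requests from
        # the running batch.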
        output_rids = []
        prev_output_strs = []
        output_tokens = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished():
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

            if req.finished() or (
                (
                    req.stream
                    and (
                        self.decode_forward_ct % self.stream_interval == 0
                        or len(req.output_ids) == 1
                    )
                )
            ):
                output_rids.append(req.rid)
                prev_output_strs.append(req.prev_output_str)
                output_tokens.append(req.output_ids)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                output_spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )

                meta_info = {
                    "prompt_tokens": len(req.origin_input_ids),
                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": str(req.finished_reason),
                }
                if req.return_logprob:
                    (
                        meta_info["prefill_token_logprobs"],
                        meta_info["decode_token_logprobs"],
                        meta_info["prefill_top_logprobs"],
                        meta_info["decode_top_logprobs"],
                        meta_info["normalized_prompt_logprob"],
                    ) = (
                        req.prefill_token_logprobs,
                        req.decode_token_logprobs,
                        req.prefill_top_logprobs,
                        req.decode_top_logprobs,
                        req.normalized_prompt_logprob,
                    )
                output_meta_info.append(meta_info)
                output_finished_reason.append(req.finished_reason)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    prev_output_strs,
                    output_tokens,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
                    output_finished_reason,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
                    req_pool_idx=req_pool_indices_cpu[i],
                )

                self.tree_cache.dec_lock_ref(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []

    def flush_cache(self):
        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

    def abort_request(self, recv_req):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.forward_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.forward_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid:
                    req.finished_reason = FINISH_ABORT()
                    break


class ModelTpService(rpyc.Service):
    exposed_ModelTpServer = ModelTpServer


class ModelTpClient:
    def __init__(
        self,
        gpu_ids: List[int],
        server_args: ServerArgs,
        model_port_args: ModelPortArgs,
        model_overide_args,
    ):
        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
        self.tp_size = server_args.tp_size

        if self.tp_size * server_args.dp_size == 1:
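            # Single-process path: run the TP server in this process and expose a
            # trivial async wrapper around its step() call.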
            # Init model
            assert len(gpu_ids) == 1
            self.model_server = ModelTpService().exposed_ModelTpServer(
                0,
                gpu_ids[0],
                server_args,
                model_port_args,
                model_overide_args,
            )

            # Wrap functions
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
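            # Multi-process path: launch one rpyc worker process per tensor-parallel
            # rank, broadcast each call to all of them, and return the first
            # worker's result.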
            with ThreadPoolExecutor(self.tp_size) as executor:
                # Launch model processes
                rets = executor.map(
                    lambda args: start_rpyc_process(*args),
                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
                )
                self.model_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.model_services[i].ModelTpServer(
                        gpu_ids[i],
                        i,
                        server_args,
                        model_port_args,
                        model_overide_args,
                    )

                self.model_servers = executor.map(init_model, range(self.tp_size))

            # Wrap functions
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")