"vscode:/vscode.git/clone" did not exist on "1027d22ff62be6129fc119e820eaef47613c311a"
io_struct.py 47.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Scheduler).
"""

import copy
import uuid
from abc import ABC
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.utils import ImageData

# Handle serialization of Image for pydantic
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any

@dataclass
class BaseReq(ABC):
    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        if isinstance(self.rid, list):
            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
        else:
            self.rid = uuid.uuid4().hex
        return self.rid


@dataclass
class BaseBatchReq(ABC):
    rids: Optional[List[str]] = field(default=None, kw_only=True)

    def regenerate_rids(self):
        """Generate new request IDs and return them."""
        self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
        return self.rids
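
# Illustrative sketch (in comments, not executed at import time): every IO
# struct carries a request ID that can be refreshed in place, e.g. on retry.
#
#     req = GenerateReqInput(text="hello")    # BaseReq subclass defined below
#     new_rid = req.regenerate_rid()          # rid is now a fresh uuid4 hex string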


# Parameters for a session
@dataclass
class SessionParams:
    id: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
    drop_previous_output: Optional[bool] = None
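
# Illustrative sketch: session parameters are supplied to GenerateReqInput as a
# plain dict and parsed into SessionParams further down the pipeline, e.g.:
#
#     params = {"id": "sess-1", "offset": 0, "replace": False}
#     req = GenerateReqInput(text="continue the story", session_params=params)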


# Type definitions for multimodal input data
# Individual data item types for each modality
ImageDataInputItem = Union[Image, str, ImageData, Dict]
AudioDataInputItem = Union[str, Dict]
VideoDataInputItem = Union[str, Dict]
# Union type for any multimodal data item
MultimodalDataInputItem = Union[
    ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
]
# Format types supporting single items, lists, or nested lists for batch processing
MultimodalDataInputFormat = Union[
    List[List[MultimodalDataInputItem]],
    List[MultimodalDataInputItem],
    MultimodalDataInputItem,
]
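
# Illustrative examples of the accepted shapes (paths are placeholders):
#
#     "https://example.com/cat.png"                    # one item for a single request
#     ["img0.png", "img1.png"]                         # one item per request in a batch of 2
#     [["img0a.png", "img0b.png"], ["img1.png"]]       # multiple items per request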


@dataclass
class GenerateReqInput(BaseReq):
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # Whether to return hidden states
    return_hidden_states: Union[List[bool], bool] = False
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # The path to the LoRA adaptors
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # The uid of LoRA adaptors, should be initialized by tokenizer manager
    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None
    bootstrap_pair_key: Optional[Union[List[str], str]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For background responses (OpenAI responses API)
    background: bool = False

    # Conversation id used for tracking requests
    conversation_id: Optional[str] = None

    # Priority for the request
    priority: Optional[int] = None

    # Extra key for classifying the request (e.g. cache_salt)
    extra_key: Optional[Union[List[str], str]] = None

    # Whether to disallow logging for this request (e.g. due to ZDR)
    no_logs: bool = False

    # For custom metric labels
    custom_labels: Optional[Dict[str, str]] = None

    # (Internal) Whether to return bytes for image generation
    return_bytes: bool = False

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()
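
    # Example (illustrative): a single prompt with n == 2 parallel samples is
    # converted into a batch of two expanded requests with matching rids:
    #
    #     req = GenerateReqInput(text="hi", sampling_params={"n": 2})
    #     req.normalize_batch_and_arguments()
    #     assert len(req.text) == 2 and len(req.rid) == 2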

    def _validate_inputs(self):
        """Validate that the input configuration is valid."""
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size."""
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed."""
        # Determine parallel sample count
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )
        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
            if self.input_embeds is not None:
                self.input_embeds = [self.input_embeds]
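
    # Example (illustrative): with text="hi" and sampling_params={"n": 3},
    # is_single flips to False and text becomes ["hi"]; the replication into
    # 3 copies happens later in _normalize_batch_inputs via _expand_inputs.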

    def _normalize_single_inputs(self):
        """Normalize inputs for a single example."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Calculate expanded batch size
        if self.parallel_sample_num == 1:
            num = self.batch_size
        else:
            # Expand parallel_sample_num
            num = self.batch_size * self.parallel_sample_num

        # Expand input based on type
        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_video_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)
        self._normalize_bootstrap_params(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Normalize LoRA paths for batch processing."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data for batch processing."""
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image, convert to list of single-image lists
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        elif isinstance(self.image_data, list):
            # Handle empty list case - treat as no images
            if len(self.image_data) == 0:
                self.image_data = [None] * num
                return

            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is
                for i in range(len(self.image_data)):
                    if self.image_data[i] is None or self.image_data[i] == [None]:
                        self.modalities.append(None)
                    elif len(self.image_data[i]) == 1:
                        self.modalities.append("image")
                    elif len(self.image_data[i]) > 1:
                        self.modalities.append("multi-images")
                    else:
                        # Ensure len(self.modalities) == len(self.image_data)
                        self.modalities.append(None)
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

    def _normalize_video_data(self, num):
        """Normalize video data for batch processing."""
        if self.video_data is None:
            self.video_data = [None] * num
        elif not isinstance(self.video_data, list):
            self.video_data = [self.video_data] * num
        elif isinstance(self.video_data, list):
            self.video_data = self.video_data * self.parallel_sample_num

    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        elif isinstance(self.audio_data, list):
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters for batch processing."""
        if self.sampling_params is None:
            self.sampling_params = [{}] * num
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            new_rids = [f"{self.rid}_{i}" for i in range(num)]
            self.rid = new_rids
        elif isinstance(self.rid, list):
            # Note: the length of rid shall be the same as the batch_size,
            # as the rid would be expanded for parallel sampling in tokenizer_manager
            if len(self.rid) != self.batch_size:
                raise ValueError(
                    "The length of rid should be equal to the batch size for batch processing."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")
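
    # Example (illustrative): a user-supplied string rid "abc" with num == 4
    # (e.g., batch_size == 2 and n == 2) becomes ["abc_0", "abc_1", "abc_2", "abc_3"].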

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing."""

        # Helper function to normalize a parameter
        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        # Normalize each logprob parameter
        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing."""
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _normalize_bootstrap_params(self, num):
        """Normalize bootstrap parameters for batch processing."""
        # Normalize bootstrap_host
        if self.bootstrap_host is None:
            self.bootstrap_host = [None] * num
        elif not isinstance(self.bootstrap_host, list):
            self.bootstrap_host = [self.bootstrap_host] * num
        elif isinstance(self.bootstrap_host, list):
            self.bootstrap_host = self.bootstrap_host * self.parallel_sample_num

        # Normalize bootstrap_port
        if self.bootstrap_port is None:
            self.bootstrap_port = [None] * num
        elif not isinstance(self.bootstrap_port, list):
            self.bootstrap_port = [self.bootstrap_port] * num
        elif isinstance(self.bootstrap_port, list):
            self.bootstrap_port = self.bootstrap_port * self.parallel_sample_num

        # Normalize bootstrap_room
        if self.bootstrap_room is None:
            self.bootstrap_room = [None] * num
        elif not isinstance(self.bootstrap_room, list):
            self.bootstrap_room = [self.bootstrap_room + i for i in range(num)]
        elif isinstance(self.bootstrap_room, list):
            self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num

        # Normalize bootstrap_pair_key
        if self.bootstrap_pair_key is None:
            self.bootstrap_pair_key = [None] * num
        elif not isinstance(self.bootstrap_pair_key, list):
            self.bootstrap_pair_key = [self.bootstrap_pair_key] * num
        elif isinstance(self.bootstrap_pair_key, list):
            self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num

    def _validate_session_params(self):
        """Validate that session parameters are properly formatted."""
        if self.session_params is not None:
            if not isinstance(self.session_params, dict) and not isinstance(
                self.session_params[0], dict
            ):
                raise ValueError("Session params must be a dict or a list of dicts.")

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            input_embeds=(
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
            video_data=self.video_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            modalities=self.modalities[i] if self.modalities else None,
            session_params=self.session_params,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            lora_id=self.lora_id[i] if self.lora_id is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_port=(
                self.bootstrap_port[i] if self.bootstrap_port is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            bootstrap_pair_key=(
                self.bootstrap_pair_key[i]
                if self.bootstrap_pair_key is not None
                else None
            ),
            data_parallel_rank=(
                self.data_parallel_rank if self.data_parallel_rank is not None else None
            ),
            conversation_id=self.conversation_id,
            priority=self.priority,
            extra_key=self.extra_key,
            no_logs=self.no_logs,
            custom_labels=self.custom_labels,
            return_bytes=self.return_bytes,
        )
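
# Illustrative end-to-end sketch (in comments, not executed at import time):
#
#     req = GenerateReqInput(text=["a", "b"], sampling_params={"n": 2})
#     req.normalize_batch_and_arguments()
#     # Per-request fields are now lists of length batch_size * n == 4, so the
#     # tokenizer manager can peel off one sub-request at a time:
#     sub_req = req[0]
#     assert sub_req.text == "a" and isinstance(sub_req.rid, str)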


@dataclass
class TokenizedGenerateReqInput(BaseReq):
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token ids to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # Whether to return hidden states
    return_hidden_states: bool = False

    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # LoRA related
    lora_id: Optional[str] = None  # None means just use the base model

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None
    bootstrap_pair_key: Optional[str] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # Priority for the request
    priority: Optional[int] = None

    # Extra key for classifying the request (e.g. cache_salt)
    extra_key: Optional[str] = None

    # Whether to disallow logging for this request (e.g. due to ZDR)
    no_logs: bool = False

    # tracing context
    trace_context: Optional[Dict] = None

    # (Internal) Whether to return bytes for image generation
    return_bytes: bool = False


@dataclass
class BatchTokenizedGenerateReqInput(BaseBatchReq):
    # The batch of tokenized requests
    batch: List[TokenizedGenerateReqInput]

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]

    def __iter__(self):
        return iter(self.batch)
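
# Illustrative sketch (assumes SamplingParams() is constructible with defaults):
#
#     reqs = [TokenizedGenerateReqInput(
#         input_text="hi", input_ids=[1, 2], mm_inputs={},
#         sampling_params=SamplingParams(), return_logprob=False,
#         logprob_start_len=-1, top_logprobs_num=0,
#         token_ids_logprob=[], stream=False, rid="r0")]
#     batch = BatchTokenizedGenerateReqInput(batch=reqs, rids=["r0"])
#     assert len(batch) == 1 and batch[0].rid == "r0"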


@dataclass
class EmbeddingReqInput(BaseReq):
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False
    # Priority for the request
    priority: Optional[int] = None

    # For background responses (OpenAI responses API)
    background: bool = False

    # tracing context
    trace_context: Optional[Dict] = None

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                self.sampling_params = [{}] * self.batch_size
            elif isinstance(self.sampling_params, dict):
                self.sampling_params = [self.sampling_params] * self.batch_size
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def __getitem__(self, i):
        if self.is_cross_encoder_request:
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                is_cross_encoder_request=True,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            audio_data=self.audio_data[i] if self.audio_data is not None else None,
            video_data=self.video_data[i] if self.video_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
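
# Illustrative sketch: embedding requests always force max_new_tokens == 0.
#
#     req = EmbeddingReqInput(text=["query one", "query two"])
#     req.normalize_batch_and_arguments()
#     assert req.batch_size == 2
#     assert all(sp["max_new_tokens"] == 0 for sp in req.sampling_params)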


@dataclass
class TokenizedEmbeddingReqInput(BaseReq):
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The token type ids
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None
    # Priority for the request
    priority: Optional[int] = None


@dataclass
class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
    # The batch of tokenized embedding requests
    batch: List[TokenizedEmbeddingReqInput]

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]

    def __iter__(self):
        return iter(self.batch)


@dataclass
class BatchTokenIDOutput(BaseBatchReq):
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

    # The information of placeholder tokens (e.g., image token)
    # idx is the index of the token in the prompt after expansion.
    # val is the length of padded tokens after expansion.
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]


@dataclass
class BatchMultimodalDecodeReq(BaseBatchReq):
    decoded_ids: List[int]
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    read_offsets: List[int]
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    image_resolutions: List[List[int]]
    resize_image_resolutions: List[List[int]]

    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    # Placeholder token info
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    return_bytes: bool = False


@dataclass
class BatchStrOutput(BaseBatchReq):
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]


@dataclass
class BatchMultimodalOutput(BaseBatchReq):
    # The finish reason
    finished_reasons: List[dict]
    decoded_ids: List[List[int]]
    # The outputs
    outputs: Union[List[str | bytes], List[List[Dict]]]

    # probability values for input tokens and output tokens
    input_token_logprobs_val: List[List[float]]
    input_token_logprobs_idx: List[List[int]]
    output_token_logprobs_val: List[List[float]]
    output_token_logprobs_idx: List[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    return_bytes: List[bool]


@dataclass
class BatchEmbeddingOutput(BaseBatchReq):
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]
    # Placeholder token info
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]


@dataclass
class ClearHiCacheReqInput(BaseReq):
    pass


@dataclass
class ClearHiCacheReqOutput(BaseReq):
    success: bool


@dataclass
class FlushCacheReqInput(BaseReq):
    pass


@dataclass
class FlushCacheReqOutput(BaseReq):
    success: bool


@dataclass
class UpdateWeightFromDiskReqInput(BaseReq):
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None
    # Whether to update weights asynchronously
    is_async: bool = False
    # Whether to empty torch cache
    torch_empty_cache: bool = False
    # Whether to keep the scheduler paused after weight update
    keep_pause: bool = False


@dataclass
class UpdateWeightFromDiskReqOutput(BaseReq):
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
class UpdateWeightsFromDistributedReqInput(BaseReq):
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    # The group name
    group_name: str = "weight_update_group"
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None


@dataclass
class UpdateWeightsFromDistributedReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput(BaseReq):
    """Update model weights from tensor input.

    - Tensors are serialized for transmission
    - Data is structured in JSON for easy transmission over HTTP
    """

    serialized_named_tensors: List[Union[str, bytes]]
    # Optional format specification for loading
    load_format: Optional[str] = None
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None


@dataclass
class UpdateWeightsFromTensorReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
    # The master address
    master_address: str
    # The ports for each rank's communication group
    ports: str
    # The rank in the communication group
    group_rank: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_send_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class SendWeightsToRemoteInstanceReqInput(BaseReq):
    # The master address
    master_address: str
    # The ports for each rank's communication group
    ports: str
    # The group name
    group_name: str = "weight_send_group"


@dataclass
class SendWeightsToRemoteInstanceReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput(BaseReq):
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"
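
# Illustrative sketch (addresses and sizes are placeholders): the training side
# opens a weight-update group and the inference side joins it with this request.
#
#     req = InitWeightsUpdateGroupReqInput(
#         master_address="10.0.0.1", master_port=29500,
#         rank_offset=1, world_size=9,
#         group_name="weight_update_group", backend="nccl")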


@dataclass
class InitWeightsUpdateGroupReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class DestroyWeightsUpdateGroupReqInput(BaseReq):
    group_name: str = "weight_update_group"


@dataclass
class DestroyWeightsUpdateGroupReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class UpdateWeightVersionReqInput(BaseReq):
    # The new weight version
    new_version: str
    # Whether to abort all running requests before updating
    abort_all_requests: bool = True


@dataclass
class GetWeightsByNameReqInput(BaseReq):
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput(BaseReq):
    parameter: list


1127
@dataclass
class ReleaseMemoryOccupationReqInput(BaseReq):
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ReleaseMemoryOccupationReqOutput(BaseReq):
    pass


@dataclass
class ResumeMemoryOccupationReqInput(BaseReq):
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ResumeMemoryOccupationReqOutput(BaseReq):
    pass


@dataclass
class SlowDownReqInput(BaseReq):
    forward_sleep_time: Optional[float]


@dataclass
class SlowDownReqOutput(BaseReq):
    pass


@dataclass
class AbortReq(BaseReq):
    # Whether to abort all requests
    abort_all: bool = False
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None
    abort_reason: Optional[str] = None

    def __post_init__(self):
        # FIXME: This is a hack to keep the same with the old code
        if self.rid is None:
            self.rid = ""
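
# Illustrative sketch: abort one request by rid, or everything at once.
#
#     AbortReq(rid="abc123")          # abort a single request
#     AbortReq(abort_all=True)        # abort all in-flight requests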


@dataclass
class GetInternalStateReq(BaseReq):
    pass


@dataclass
class GetInternalStateReqOutput(BaseReq):
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq(BaseReq):
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput(BaseReq):
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput(BaseReq):
    # The output directory
    output_dir: Optional[str] = None
    # If set, profile at most this number of steps.
    # Profiling is then stopped automatically, so the caller
    # doesn't need to run stop_profile.
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None


class ProfileReqType(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class ProfileReq(BaseReq):
    type: ProfileReqType
    output_dir: Optional[str] = None
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    profile_id: Optional[str] = None


@dataclass
class ProfileReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class FreezeGCReq(BaseReq):
    pass


@dataclass
class ConfigureLoggingReq(BaseReq):
    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None
    crash_dump_folder: Optional[str] = None


@dataclass
class OpenSessionReqInput(BaseReq):
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput(BaseReq):
    session_id: str


@dataclass
class OpenSessionReqOutput(BaseReq):
    session_id: Optional[str]
    success: bool


@dataclass
class HealthCheckOutput(BaseReq):
    pass


class ExpertDistributionReqType(Enum):
    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


@dataclass
class ExpertDistributionReq(BaseReq):
    action: ExpertDistributionReqType


@dataclass
class ExpertDistributionReqOutput(BaseReq):
    pass


@dataclass
class Function:
    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq(BaseReq):
    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
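
# Illustrative sketch (the tool schema is a placeholder):
#
#     tool = Tool(function=Function(
#         name="get_weather",
#         description="Look up the weather for a city.",
#         parameters={"type": "object", "properties": {"city": {"type": "string"}}}))
#     req = ParseFunctionCallReq(text='{"name": "get_weather", ...}', tools=[tool])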


@dataclass
class SeparateReasoningReqInput(BaseReq):
    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


@dataclass
class VertexGenerateReqInput(BaseReq):
    instances: List[dict]
    parameters: Optional[dict] = None


@dataclass
class RpcReqInput(BaseReq):
    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class LoadLoRAAdapterReqInput(BaseReq):
    # The name of the LoRA module to be newly loaded.
    lora_name: str
    # The path to load the adapter from.
    lora_path: str
    # Whether to pin the LoRA adapter in memory.
    pinned: bool = False
    # The unique identifier for the LoRA adapter, which is automatically generated by the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
            lora_path=self.lora_path,
            pinned=self.pinned,
        )
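
# Illustrative sketch: the TokenizerManager assigns `lora_id` before the request
# is converted into a registry entry.
#
#     req = LoadLoRAAdapterReqInput(lora_name="my-adapter", lora_path="/path/to/adapter")
#     req.lora_id = uuid.uuid4().hex
#     ref = req.to_ref()    # LoRARef carrying lora_id, lora_name, lora_path, pinned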


@dataclass
class UnloadLoRAAdapterReqInput(BaseReq):
    # The name of the LoRA module to unload.
    lora_name: str
    # The unique identifier for the LoRA adapter, which is automatically generated by the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
        )


@dataclass
class LoRAUpdateOutput(BaseReq):
    success: bool
    error_message: Optional[str] = None
    loaded_adapters: Optional[Dict[str, LoRARef]] = None


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput


@dataclass
class MultiTokenizerRegisterReq(BaseBatchReq):
    ipc_name: Optional[str] = None


@dataclass
class MultiTokenizerWrapper:
    # FIXME(lsyin): remove this
    worker_id: int
    obj: Optional[Any] = None


class BlockReqType(Enum):
    BLOCK = 1
    UNBLOCK = 2


@dataclass
class BlockReqInput(BaseReq):
    type: BlockReqType


@dataclass
class GetLoadReqInput(BaseReq):
    pass


@dataclass
class GetLoadReqOutput(BaseReq):
    dp_rank: int
    num_reqs: int
    num_waiting_reqs: int
    num_tokens: int


@dataclass
class WatchLoadUpdateReq(BaseReq):
    loads: List[GetLoadReqOutput]


def _check_all_req_types():
    """Check that all request types defined in this file subclass BaseReq or
    BaseBatchReq and follow the naming convention."""
    import inspect
    import sys

    all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
    for name, cls in all_classes:
        # Check the class name against the io-struct naming convention.
        is_io_struct = (
            name.endswith("Req") or name.endswith("Input") or name.endswith("Output")
        )
        is_base_req = issubclass(cls, (BaseReq, BaseBatchReq))
        if is_io_struct and not is_base_req:
            raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
        if is_base_req and not is_io_struct:
            raise ValueError(
                f"{name} is a subclass of BaseReq but does not follow the naming convention."
            )


_check_all_req_types()