# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Scheduler).
"""

import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.utils import ImageData

# Handle serialization of Image for pydantic
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any


@dataclass
class SessionParams:
    id: Optional[str] = None
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
    drop_previous_output: Optional[bool] = None


# Type definitions for multimodal input data
# Individual data item types for each modality
ImageDataInputItem = Union[Image, str, ImageData, Dict]
AudioDataInputItem = Union[str, Dict]
VideoDataInputItem = Union[str, Dict]
# Union type for any multimodal data item
MultimodalDataInputItem = Union[
    ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
]
# Format types supporting single items, lists, or nested lists for batch processing
MultimodalDataInputFormat = Union[
    List[List[MultimodalDataInputItem]],
    List[MultimodalDataInputItem],
    MultimodalDataInputItem,
]
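
# Example (illustrative) of the three accepted shapes, e.g. for `image_data`
# below; the file names are hypothetical:
#   "cat.png"                         -> a single item for a single request
#   ["cat.png", "dog.png"]            -> one item per request in a batch of 2
#   [["a.png", "b.png"], ["c.png"]]   -> multiple items per request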


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a URL, or a base64-encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a URL, or a base64-encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprobs for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens into text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics).
    log_metrics: bool = True
    # Whether to return hidden states.
    return_hidden_states: Union[List[bool], bool] = False

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # The paths to the LoRA adapters
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # The uids of the LoRA adapters; initialized by the TokenizerManager
    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
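    # Example (illustrative), assuming `MyProcessor` is a user-defined
    # `CustomLogitProcessor` subclass:
    #   req.custom_logit_processor = MyProcessor().to_str()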

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For background responses (OpenAI responses API)
    background: bool = False

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()
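
    # Example (illustrative): a single prompt with sampling_params={"n": 2} is
    # expanded into a batch, so every per-request field becomes a list:
    #   req = GenerateReqInput(text="hi", sampling_params={"n": 2})
    #   req.normalize_batch_and_arguments()
    #   req.text       # -> ["hi", "hi"]
    #   req.is_single  # -> False, with a fresh rid generated per sample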

    def _validate_inputs(self):
        """Validate that the input configuration is valid."""
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size."""
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed."""
        # Determine parallel sample count
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )

        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
            if self.input_embeds is not None:
                self.input_embeds = [self.input_embeds]

    def _normalize_single_inputs(self):
        """Normalize inputs for a single example."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Calculate expanded batch size
        if self.parallel_sample_num == 1:
            num = self.batch_size
        else:
            # Expand parallel_sample_num
            num = self.batch_size * self.parallel_sample_num

        # Expand input based on type
        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_video_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Normalize LoRA paths for batch processing."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data for batch processing."""
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image, convert to list of single-image lists
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        elif isinstance(self.image_data, list):
            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is
                for i in range(len(self.image_data)):
                    if self.image_data[i] is None or self.image_data[i] == [None]:
                        self.modalities.append(None)
                    elif len(self.image_data[i]) == 1:
                        self.modalities.append("image")
                    elif len(self.image_data[i]) > 1:
                        self.modalities.append("multi-images")
                    else:
                        # Ensure len(self.modalities) == len(self.image_data)
                        self.modalities.append(None)
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

    def _normalize_video_data(self, num):
        """Normalize video data for batch processing."""
        if self.video_data is None:
            self.video_data = [None] * num
        elif not isinstance(self.video_data, list):
            self.video_data = [self.video_data] * num
        elif isinstance(self.video_data, list):
            self.video_data = self.video_data * self.parallel_sample_num

    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        elif isinstance(self.audio_data, list):
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters for batch processing."""
        if self.sampling_params is None:
            self.sampling_params = [{}] * num
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            new_rids = [f"{self.rid}_{i}" for i in range(num)]
            self.rid = new_rids
        elif isinstance(self.rid, list):
            # Note: the length of rid shall be the same as the batch_size,
            # as the rid would be expanded for parallel sampling in tokenizer_manager
            if len(self.rid) != self.batch_size:
                raise ValueError(
                    "The length of the specified rids does not match the batch_size for batch processing."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing."""

        # Helper function to normalize a parameter
        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        # Normalize each logprob parameter
        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing."""
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _validate_session_params(self):
        """Validate that session parameters are properly formatted."""
        if self.session_params is not None:
            if not isinstance(self.session_params, dict) and not isinstance(
                self.session_params[0], dict
            ):
                raise ValueError("Session params must be a dict or a list of dicts.")

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            input_embeds=(
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
            video_data=self.video_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            lora_id=self.lora_id[i] if self.lora_id is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            # if `__getitem__` is called, bootstrap_host, bootstrap_port, and bootstrap_room must be lists
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_port=(
                self.bootstrap_port[i] if self.bootstrap_port is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            data_parallel_rank=self.data_parallel_rank,
        )
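
    # Example (illustrative): after normalization, indexing splits a batch
    # request into per-sample requests:
    #   first = req[0]  # GenerateReqInput carrying rid[0], text[0], etc.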


@dataclass
class TokenizedGenerateReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token ids to return logprobs for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_id: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For dp balance
    dp_balance_id: int = -1


@dataclass
class BatchTokenizedGenerateReqInput:
    # The batch of tokenized requests
    batch: List[TokenizedGenerateReqInput]

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]

    def __iter__(self):
        return iter(self.batch)


@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a URL, or a base64-encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a URL, or a base64-encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False

    # For background responses (OpenAI responses API)
    background: bool = False

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                self.sampling_params = [{}] * self.batch_size
            elif isinstance(self.sampling_params, dict):
                self.sampling_params = [self.sampling_params] * self.batch_size
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0
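
    # Example (illustrative): a single embedding request; normalization fills
    # in a rid and forces max_new_tokens to 0, since embedding needs no decoding.
    #   req = EmbeddingReqInput(text="hello world")
    #   req.normalize_batch_and_arguments()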

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def __getitem__(self, i):
        if self.is_cross_encoder_request:
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                is_cross_encoder_request=True,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            audio_data=self.audio_data[i] if self.audio_data is not None else None,
            video_data=self.video_data[i] if self.video_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )


@dataclass
class TokenizedEmbeddingReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The token type ids
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None
    # For dp balance
    dp_balance_id: int = -1


@dataclass
class BatchTokenizedEmbeddingReqInput:
    # The batch of tokenized embedding requests
    batch: List[TokenizedEmbeddingReqInput]

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]

    def __iter__(self):
        return iter(self.batch)


@dataclass
class BatchTokenIDOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalDecodeReq:
    # The request id
    rids: List[str]
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchStrOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchEmbeddingOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class ClearHiCacheReqInput:
    pass


@dataclass
class ClearHiCacheReqOutput:
    success: bool


@dataclass
class FlushCacheReqInput:
    pass


@dataclass
class FlushCacheReqOutput:
    success: bool


@dataclass
class UpdateWeightFromDiskReqInput:
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None


@dataclass
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
class UpdateWeightsFromDistributedReqInput:
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    # The group name
    group_name: str = "weight_update_group"
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    """Update model weights from tensor input.

    - Tensors are serialized for transmission
    - Data is structured in JSON for easy transmission over HTTP
    """

    serialized_named_tensors: List[Union[str, bytes]]
    # Optional format specification for loading
    load_format: Optional[str] = None
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None


@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    success: bool
    message: str
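
# Example (illustrative): an RL trainer could first ask the engine to join a
# weight-update process group and then broadcast tensors over it. The address,
# tensor name, and shape below are hypothetical.
#   init_req = InitWeightsUpdateGroupReqInput(
#       master_address="10.0.0.1",
#       master_port=29500,
#       rank_offset=1,
#       world_size=2,
#   )
#   update_req = UpdateWeightsFromDistributedReqInput(
#       names=["lm_head.weight"],
#       dtypes=["bfloat16"],
#       shapes=[[151936, 4096]],
#   )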


@dataclass
class UpdateWeightVersionReqInput:
    # The new weight version
    new_version: str
    # Whether to abort all running requests before updating
    abort_all_requests: bool = True


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list


@dataclass
class ReleaseMemoryOccupationReqInput:
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ReleaseMemoryOccupationReqOutput:
    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ResumeMemoryOccupationReqOutput:
    pass
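
# Example (illustrative): an RL workflow can release the GPU memory held by the
# KV cache during a training phase and resume it before the next rollout.
#   release = ReleaseMemoryOccupationReqInput(tags=["kv_cache"])
#   resume = ResumeMemoryOccupationReqInput(tags=["kv_cache"])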


@dataclass
class SlowDownReqInput:
    forward_sleep_time: Optional[float]


@dataclass
class SlowDownReqOutput:
    pass


@dataclass
class AbortReq:
    # The request id
    rid: str = ""
    # Whether to abort all requests
    abort_all: bool = False
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None
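
    # Example (illustrative): abort a single request by rid, or all requests.
    #   AbortReq(rid="some-request-id")  # hypothetical rid
    #   AbortReq(abort_all=True)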


@dataclass
class GetInternalStateReq:
    pass


@dataclass
class GetInternalStateReqOutput:
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    # The output directory
    output_dir: Optional[str] = None
    # If set, profile only this number of steps.
    # Profiling is stopped automatically after that many steps, so the
    # caller doesn't need to run stop_profile.
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None


class ProfileReqType(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class ProfileReq:
    type: ProfileReqType
    output_dir: Optional[str] = None
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    profile_id: Optional[str] = None


@dataclass
class ProfileReqOutput:
    success: bool
    message: str


@dataclass
class FreezeGCReq:
    pass


@dataclass
class ConfigureLoggingReq:
    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    session_id: str


@dataclass
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool


@dataclass
class HealthCheckOutput:
    pass


class ExpertDistributionReq(Enum):
    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


@dataclass
class ExpertDistributionReqOutput:
    pass


@dataclass
class Function:
    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
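
    # Example (illustrative): parse tool calls out of a model completion with
    # one declared tool; `completion_text` is a hypothetical model output.
    #   get_weather = Tool(function=Function(name="get_weather", parameters={}))
    #   req = ParseFunctionCallReq(
    #       text=completion_text, tools=[get_weather], tool_call_parser="qwen25"
    #   )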


@dataclass
class SeparateReasoningReqInput:
    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".
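
    # Example (illustrative): split the reasoning trace from the final answer;
    # `model_output` is a hypothetical completion string.
    #   req = SeparateReasoningReqInput(text=model_output, reasoning_parser="deepseek-r1")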


@dataclass
class VertexGenerateReqInput:
    instances: List[dict]
    parameters: Optional[dict] = None


@dataclass
class RpcReqInput:
    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    success: bool
    message: str


@dataclass
class LoadLoRAAdapterReqInput:
    # The name of the LoRA adapter to be newly loaded.
    lora_name: str
    # The path to load the LoRA adapter from.
    lora_path: str
    # Whether to pin the LoRA adapter in memory.
    pinned: bool = False
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
            lora_path=self.lora_path,
            pinned=self.pinned,
        )
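
    # Example (illustrative): the TokenizerManager fills in `lora_id` and then
    # registers the adapter via `to_ref()`; the name and path are hypothetical.
    #   req = LoadLoRAAdapterReqInput(lora_name="my-adapter", lora_path="/tmp/adapter")
    #   req.lora_id = uuid.uuid4().hex
    #   ref = req.to_ref()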


@dataclass
class UnloadLoRAAdapterReqInput:
    # The name of the LoRA adapter to unload.
    lora_name: str
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
        )


@dataclass
class LoRAUpdateResult:
    success: bool
    error_message: Optional[str] = None
    loaded_adapters: Optional[Dict[str, LoRARef]] = None


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult


class BlockReqType(Enum):
    BLOCK = 1
    UNBLOCK = 2


@dataclass
class BlockReqInput:
    type: BlockReqType