io_struct.py 37.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Scheduler).
"""

19
import copy
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import uuid
YAMY's avatar
YAMY committed
21
from dataclasses import dataclass, field
22
from enum import Enum
23
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
24

25
from sglang.srt.lora.lora_registry import LoRARef
26
from sglang.srt.managers.schedule_batch import BaseFinishReason
27
from sglang.srt.multimodal.mm_utils import has_valid_data
28
from sglang.srt.sampling.sampling_params import SamplingParams
29
from sglang.srt.utils import ImageData
30

31
# Handle serialization of Image for pydantic.
# Static type checkers see the real PIL class; at runtime `Image` is simply
# an alias of `Any`, so PIL is not imported when this module is loaded.
if not TYPE_CHECKING:
    Image = Any
else:
    from PIL.Image import Image
36

Lianmin Zheng's avatar
Lianmin Zheng committed
37

38
39
40
41
42
43
@dataclass
class SessionParams:
    """Session options attached to a request for continual prompting.

    Every field is optional and defaults to ``None``.
    """

    # Session identifier.
    id: Optional[str] = field(default=None)
    # Request id within the session.
    rid: Optional[str] = field(default=None)
    # Offset within the session.
    offset: Optional[int] = field(default=None)
    # Replace flag.
    replace: Optional[bool] = field(default=None)
    # Whether to drop the previous output.
    drop_previous_output: Optional[bool] = field(default=None)
45
46


47
48
# Type definitions for multimodal input data.

# Individual data item types, one per modality.
ImageDataInputItem = Union[Image, str, ImageData, Dict]
AudioDataInputItem = Union[str, Dict]
VideoDataInputItem = Union[str, Dict]

# Any single multimodal item, regardless of modality.
MultimodalDataInputItem = Union[
    ImageDataInputItem,
    VideoDataInputItem,
    AudioDataInputItem,
]

# Accepted container shapes for batch processing: a nested list (several
# items per request), a flat list, or a single bare item.
MultimodalDataInputFormat = Union[
    List[List[MultimodalDataInputItem]],
    List[MultimodalDataInputItem],
    MultimodalDataInputItem,
]
62
63


Lianmin Zheng's avatar
Lianmin Zheng committed
64
65
@dataclass
class GenerateReqInput:
    """A generation request as accepted from the user-facing API.

    One instance describes either a single request or a batch of requests.
    ``normalize_batch_and_arguments`` determines which, fills in defaults and
    expands parallel sampling. After that, ``obj[i]`` returns the ``i``-th
    request of a batch as its own ``GenerateReqInput``.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The sampling parameters: one dict shared by all requests, or one dict per
    # request. The key "n" (default 1) requests that many parallel samples.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # The path to the LoRA adaptors
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # The uid of LoRA adaptors, should be initialized by tokenizer manager
    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # Whether to return hidden states
    return_hidden_states: Union[List[bool], bool] = False

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    def contains_mm_input(self) -> bool:
        """Return True if the image, video or audio input holds valid data."""
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()

    def _validate_inputs(self):
        """Validate that the input configuration is valid.

        Raises ValueError when none of text/input_ids/input_embeds is given,
        or when all three are given at once.
        """
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size.

        Sets ``self.is_single`` and ``self.batch_size``. The first non-None of
        text, input_ids, input_embeds decides; when text or input_ids is used,
        input_embeds is cleared.
        """
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            # A flat list of ints is one prompt; a list of lists is a batch.
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            # A 2-level list of floats is one prompt; 3 levels is a batch.
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed.

        Sets ``self.parallel_sample_num`` from the "n" sampling parameter
        (default 1). All per-request dicts must agree on "n". A single request
        with n > 1 is wrapped into a batch of one so it can be expanded later.
        """
        # Determine parallel sample count
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )

        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
            if self.input_embeds is not None:
                self.input_embeds = [self.input_embeds]

    def _normalize_single_inputs(self):
        """Normalize inputs for a single example by filling in scalar defaults."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Calculate expanded batch size
        if self.parallel_sample_num == 1:
            num = self.batch_size
        else:
            # Expand parallel_sample_num
            num = self.batch_size * self.parallel_sample_num

        # Expand input based on type
        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_video_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling.

        The whole batch is repeated ``parallel_sample_num`` times.
        """
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Normalize LoRA paths for batch processing."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data for batch processing.

        Produces one entry per (expanded) request and, when images are
        present, a matching ``self.modalities`` list.
        """
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image, convert to list of single-image lists
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        elif len(self.image_data) == 0:
            # An empty list carries no images: treat it like None rather than
            # failing the batch-size check below.
            self.image_data = [None] * num
        else:
            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is
                for images in self.image_data:
                    if images is None or images == [None]:
                        self.modalities.append(None)
                    elif len(images) == 1:
                        self.modalities.append("image")
                    elif len(images) > 1:
                        self.modalities.append("multi-images")
                    else:
                        # Ensure len(self.modalities) == len(self.image_data)
                        self.modalities.append(None)
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

    def _normalize_video_data(self, num):
        """Normalize video data for batch processing."""
        if self.video_data is None:
            self.video_data = [None] * num
        elif not isinstance(self.video_data, list):
            self.video_data = [self.video_data] * num
        elif isinstance(self.video_data, list):
            self.video_data = self.video_data * self.parallel_sample_num

    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        elif isinstance(self.audio_data, list):
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters for batch processing."""
        if self.sampling_params is None:
            # One independent dict per request (`[{}] * num` would alias a
            # single dict across the whole batch).
            self.sampling_params = [{} for _ in range(num)]
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            new_rids = [f"{self.rid}_{i}" for i in range(num)]
            self.rid = new_rids
        elif isinstance(self.rid, list):
            # Note: the length of rid shall be the same as the batch_size,
            # as the rid would be expanded for parallel sampling in tokenizer_manager
            if len(self.rid) != self.batch_size:
                raise ValueError(
                    "The specified rids length mismatch with the batch_size for batch processing."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing."""

        # Helper function to normalize a parameter
        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        # Normalize each logprob parameter
        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            # A flat list of ids applies to every request; copy it per request.
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing."""
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _validate_session_params(self):
        """Validate that session parameters are properly formatted."""
        if self.session_params is not None:
            if not isinstance(self.session_params, dict) and not isinstance(
                self.session_params[0], dict
            ):
                raise ValueError("Session params must be a dict or a list of dicts.")

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    @staticmethod
    def _pick(value, i):
        """Return ``value[i]`` for a per-request list, otherwise ``value`` itself."""
        return value[i] if isinstance(value, list) else value

    def __getitem__(self, i):
        """Return the ``i``-th request of a normalized batch as a ``GenerateReqInput``.

        Call only after ``normalize_batch_and_arguments`` on a batch: most
        per-request fields are indexed directly.
        """
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            input_embeds=(
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
            video_data=self.video_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            # lora_id is filled in on the batch object; carry it over so the
            # split request keeps its LoRA adaptor id.
            lora_id=self._pick(self.lora_id, i),
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=self._pick(self.return_hidden_states, i),
            # bootstrap_* are normally per-request lists; a scalar value
            # (instead of being indexed character-wise) applies to every request.
            bootstrap_host=self._pick(self.bootstrap_host, i),
            bootstrap_port=self._pick(self.bootstrap_port, i),
            bootstrap_room=self._pick(self.bootstrap_room, i),
            data_parallel_rank=self.data_parallel_rank,
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
481
482
483

@dataclass
class TokenizedGenerateReqInput:
    """Tokenized form of a single generation request.

    Required fields come first (no defaults); optional features follow with
    defaults, so positional construction order is significant.
    """

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token ids to return logprob for; may be None when
    # the request did not ask for specific token ids.
    token_ids_logprob: Optional[List[int]]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_id: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For dp balance
    dp_balance_id: int = -1

Lianmin Zheng's avatar
Lianmin Zheng committed
532

533
534
535
@dataclass
class EmbeddingReqInput:
    """An embedding (or cross-encoder) request as accepted from the user-facing API.

    One instance describes a single request or a batch of requests. Call
    ``normalize_batch_and_arguments`` before splitting a batch with ``obj[i]``.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False

    def normalize_batch_and_arguments(self):
        """Derive ``is_single``/``batch_size`` and fill in default arguments.

        Every request's sampling params get ``max_new_tokens = 0``. The
        caller's sampling-params dict(s) are copied first and never mutated.

        Raises:
            ValueError: If none of text, input_ids and image_data is given, if
                both text and input_ids are given, or if input_ids is empty.
        """
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            # Build a new dict so a dict the caller reuses elsewhere does not
            # pick up max_new_tokens=0.
            params = {} if self.sampling_params is None else self.sampling_params
            self.sampling_params = {**params, "max_new_tokens": 0}
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            # Give every request its own dict (copied from the caller's input)
            # before writing max_new_tokens into it.
            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif isinstance(self.sampling_params, dict):
                # A single dict applies to every request in the batch.
                self.sampling_params = [
                    dict(self.sampling_params) for _ in range(self.batch_size)
                ]
            else:
                self.sampling_params = [dict(p) for p in self.sampling_params]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        """Assign a fresh random request id and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def contains_mm_input(self) -> bool:
        """Return True if the image, video or audio input holds valid data."""
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def __getitem__(self, i):
        """Return the ``i``-th request of a normalized batch as an ``EmbeddingReqInput``."""
        if self.is_cross_encoder_request:
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                log_metrics=self.log_metrics,
                is_cross_encoder_request=True,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            audio_data=self.audio_data[i] if self.audio_data is not None else None,
            video_data=self.video_data[i] if self.video_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            log_metrics=self.log_metrics,
        )
641
642
643


@dataclass
class TokenizedEmbeddingReqInput:
    """Tokenized form of a single embedding request.

    Required fields come first (no defaults), so positional construction order
    is significant.
    """

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The token type ids
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
    # For dp balance
    dp_balance_id: int = -1
659
660


@dataclass
class BatchTokenIDOut:
    """Batched generation output expressed as token ids.

    Every field is a list; presumably index ``k`` of each list belongs to
    ``rids[k]`` (TODO confirm against producer).
    """

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalDecodeReq:
    """Batched multimodal decode request: ids, finish reasons and token counts."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchStrOut:
    """Batched generation output expressed as decoded strings.

    Every field is a list; presumably index ``k`` of each list belongs to
    ``rids[k]`` (TODO confirm against producer).
    """

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalOut:
    """Batched multimodal output with finish reasons and token counts."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchEmbeddingOut:
    """Batched embedding output with finish reasons and token counts."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class FlushCacheReqInput:
    """Request to flush the cache; carries no parameters."""

    pass


@dataclass
class FlushCacheReqOutput:
    """Result of a cache flush request."""

    # Whether the flush succeeded
    success: bool


@dataclass
class UpdateWeightFromDiskReqInput:
    """Request to reload model weights from a path on disk."""

    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False


@dataclass
class UpdateWeightFromDiskReqOutput:
    """Result of a disk weight update."""

    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request to update named weights through a distributed process group."""

    # Names, dtypes and shapes of the weights to update
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    # The group name
    group_name: str = "weight_update_group"
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Result of a distributed weight update."""

    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    """Update model weights from tensor input.

    - Tensors are serialized for transmission
    - Data is structured in JSON for easy transmission over HTTP
    """

    # Serialized named tensors (str or bytes payloads)
    serialized_named_tensors: List[Union[str, bytes]]
    # Optional format specification for loading
    load_format: Optional[str] = None
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False


@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result of a tensor-based weight update."""

    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request to initialize the process group used for weight updates."""

    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Result of initializing the weight update group."""

    success: bool
    message: str


@dataclass
class GetWeightsByNameReqInput:
    """Request to fetch a parameter by name."""

    # Parameter name to look up
    name: str
    # Truncation size applied to the returned values (default 100)
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    """Parameter values returned for a GetWeightsByName request."""

    parameter: list


@dataclass
class ReleaseMemoryOccupationReqInput:
    """Request to release memory occupation."""

    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ReleaseMemoryOccupationReqOutput:
    """Acknowledgement for a release request; carries no data."""

    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    """Request to resume memory occupation."""

    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ResumeMemoryOccupationReqOutput:
    """Acknowledgement for a resume request; carries no data."""

    pass


@dataclass
class SlowDownReqInput:
    """Request carrying a forward sleep time."""

    # Sleep time for the forward pass; None presumably disables the slowdown
    # (TODO confirm against consumer)
    forward_sleep_time: Optional[float]


@dataclass
class SlowDownReqOutput:
    """Acknowledgement for a slow-down request; carries no data."""

    pass


@dataclass
class AbortReq:
    """Request to abort one request (by ``rid``) or all requests."""

    # The request id
    rid: str = ""
    # Whether to abort all requests
    abort_all: bool = False
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None


@dataclass
class GetInternalStateReq:
    """Request for internal state; carries no parameters."""

    pass


@dataclass
class GetInternalStateReqOutput:
    """Internal state snapshot returned for a GetInternalStateReq."""

    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    """Request to update server args."""

    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    """Result of a SetInternalStateReq."""

    # Whether the update was applied
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    """User-facing profiling request parameters."""

    # The output directory
    output_dir: Optional[str] = None
    # Step at which profiling starts (semantics assumed from the name —
    # TODO confirm against the profiler implementation)
    start_step: Optional[int] = None
    # If set, profile exactly this many steps. Profiling then stops
    # automatically after the last step, and the caller doesn't need to run
    # stop_profile.
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None


class ProfileReqType(Enum):
    """Whether a ProfileReq starts or stops profiling."""

    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class ProfileReq:
    """Profiling command: a ProfileReqType plus the ProfileReqInput options."""

    type: ProfileReqType
    output_dir: Optional[str] = None
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    profile_id: Optional[str] = None


@dataclass
class ProfileReqOutput:
    """Result of a profiling command."""

    success: bool
    message: str


@dataclass
class ConfigureLoggingReq:
    """Request to change request-logging settings.

    Every field defaults to None; presumably None means "leave unchanged"
    (TODO confirm against consumer).
    """

    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None


@dataclass
class OpenSessionReqInput:
    """Request to open a session."""

    capacity_of_str_len: int
    # Optional session id
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    """Request to close the session identified by ``session_id``."""

    session_id: str


@dataclass
class OpenSessionReqOutput:
    """Result of opening a session."""

    session_id: Optional[str]
    success: bool


@dataclass
class HealthCheckOutput:
    """Health-check response; carries no data."""

    pass


class ExpertDistributionReq(Enum):
    """Commands for expert distribution recording: start, stop, or dump."""

    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


@dataclass
class ExpertDistributionReqOutput:
    """Acknowledgement for an ExpertDistributionReq; carries no data."""

    pass


@dataclass
class Function:
    """A callable tool's description, name and parameter spec."""

    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    """A tool entry wrapping a Function; ``type`` defaults to "function"."""

    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    """Request to parse function calls out of generated text."""

    # The text to parse.
    text: str
    # A list of available function tools (name, parameters, etc.).
    # default_factory gives each instance its own list.
    tools: List[Tool] = field(default_factory=list)
    # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'.
    # If not specified, tries all.
    tool_call_parser: Optional[str] = None


@dataclass
class SeparateReasoningReqInput:
    """Request to separate reasoning content from generated text."""

    # The text to parse.
    text: str
    # Specify the parser type, e.g., "deepseek-r1".
    reasoning_parser: str


@dataclass
class VertexGenerateReqInput:
    """Generate request with a list of instance dicts and optional parameters."""

    instances: List[dict]
    parameters: Optional[dict] = None


@dataclass
class RpcReqInput:
    """RPC request naming a method plus optional parameters."""

    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    """Result of an RPC request."""

    success: bool
    message: str


@dataclass
class LoadLoRAAdapterReqInput:
    """Request to load a LoRA adapter."""

    # The name of the LoRA module to load.
    lora_name: str
    # The path to load the adapter from.
    lora_path: str
    # The unique identifier for the LoRA adapter, which is automatically
    # generated in the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        """Build a LoRARef carrying this request's id, name and path."""
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
            lora_path=self.lora_path,
        )


@dataclass
class UnloadLoRAAdapterReqInput:
    """Request to unload a LoRA adapter."""

    # The name of the LoRA module to unload.
    lora_name: str
    # The unique identifier for the LoRA adapter, which is automatically
    # generated in the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        """Build a LoRARef carrying this request's id and name (no path)."""
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
        )


@dataclass
class LoRAUpdateResult:
    """Result of a LoRA load/unload request."""

    success: bool
    error_message: Optional[str] = None
    # Adapters loaded after the update; presumably keyed by adapter name
    # (TODO confirm against producer)
    loaded_adapters: Optional[Dict[str, LoRARef]] = None


# Load and unload requests share the same result type.
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult


class BlockReqType(Enum):
    """Block or unblock command carried by BlockReqInput."""

    BLOCK = 1
    UNBLOCK = 2


@dataclass
class BlockReqInput:
    type: BlockReqType