# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Scheduler).
"""

import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams

# Handle serialization of Image for pydantic
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any


@dataclass
class SessionParams:
    id: Optional[str] = None
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
    drop_previous_output: Optional[bool] = None


# Type definitions for multimodal input data
# Individual data item types for each modality
ImageDataInputItem = Union[Image, str, Dict]
AudioDataInputItem = Union[str, Dict]
VideoDataInputItem = Union[str, Dict]
# Union type for any multimodal data item
MultimodalDataInputItem = Union[
    ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
]
# Format types supporting single items, lists, or nested lists for batch processing
MultimodalDataInputFormat = Union[
    List[List[MultimodalDataInputItem]],
    List[MultimodalDataInputItem],
    MultimodalDataInputItem,
]
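
# Example (illustrative): all of the following are valid `image_data` values
# under this format; the file names are hypothetical.
#
#   "cat.png"                        # single item for a single request
#   ["a.png", "b.png"]               # one item per request in a batch
#   [["a.png", "b.png"], ["c.png"]]  # multiple items per request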


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is -1, which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics).
    log_metrics: bool = True

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # The path to the LoRA
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # Whether to return hidden states
    return_hidden_states: Union[List[bool], bool] = False

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()
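
    # Illustrative usage sketch (hypothetical values):
    #
    #   req = GenerateReqInput(text="Hello", sampling_params={"temperature": 0})
    #   req.normalize_batch_and_arguments()
    #   assert req.is_single and req.batch_size == 1
    #   assert req.rid is not None  # a fresh uuid4 hex was filled in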

    def _validate_inputs(self):
        """Validate that the input configuration is valid."""
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size."""
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed."""
        # Determine parallel sample count
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )

        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
            if self.input_embeds is not None:
                self.input_embeds = [self.input_embeds]
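
        # Illustrative (hypothetical values): a single prompt with n=2 leaves
        # here in batch form, e.g. GenerateReqInput(text="Hi",
        # sampling_params={"n": 2}) ends up with is_single=False and
        # text == ["Hi"]; _expand_inputs() later repeats the inputs
        # parallel_sample_num times.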

    def _normalize_single_inputs(self):
        """Normalize inputs for a single example."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Calculate expanded batch size
        if self.parallel_sample_num == 1:
            num = self.batch_size
        else:
            # Expand parallel_sample_num
            num = self.batch_size * self.parallel_sample_num

        # Expand input based on type
        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_video_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Normalize LoRA paths for batch processing."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data for batch processing."""
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image, convert to list of single-image lists
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        elif isinstance(self.image_data, list):
            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is
                for i in range(len(self.image_data)):
                    if self.image_data[i] is None or self.image_data[i] == [None]:
                        self.modalities.append(None)
                    elif len(self.image_data[i]) == 1:
                        self.modalities.append("image")
                    elif len(self.image_data[i]) > 1:
                        self.modalities.append("multi-images")
                    else:
                        # Ensure len(self.modalities) == len(self.image_data)
                        self.modalities.append(None)
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

    def _normalize_video_data(self, num):
        """Normalize video data for batch processing."""
        if self.video_data is None:
            self.video_data = [None] * num
        elif not isinstance(self.video_data, list):
            self.video_data = [self.video_data] * num
        elif isinstance(self.video_data, list):
            self.video_data = self.video_data * self.parallel_sample_num

    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        elif isinstance(self.audio_data, list):
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters for batch processing."""
        if self.sampling_params is None:
            self.sampling_params = [{} for _ in range(num)]  # fresh dicts, not one shared dict
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            new_rids = [f"{self.rid}_{i}" for i in range(num)]
            self.rid = new_rids
        elif isinstance(self.rid, list):
            # Note: the length of rid shall be the same as the batch_size,
            # as the rid would be expanded for parallel sampling in tokenizer_manager
            if len(self.rid) != self.batch_size:
                raise ValueError(
                    "The specified rids length mismatch with the batch_size for batch processing."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing."""

        # Helper function to normalize a parameter
        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        # Normalize each logprob parameter
        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )
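
        # Illustrative (hypothetical token ids) with num=2:
        #   token_ids_logprob=5          -> [[5], [5]]
        #   token_ids_logprob=[5, 7]     -> [[5, 7], [5, 7]] (deep-copied)
        #   token_ids_logprob=[[5], [7]] -> kept as-is (rejected when n > 1)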

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing."""
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _validate_session_params(self):
        """Validate that session parameters are properly formatted."""
        if self.session_params is not None:
            if not isinstance(self.session_params, dict) and not isinstance(
                self.session_params[0], dict
            ):
                raise ValueError("Session params must be a dict or a list of dicts.")

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            input_embeds=(
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
            video_data=self.video_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            # If `__getitem__` is called, bootstrap_host, bootstrap_port, and
            # bootstrap_room must be lists.
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_port=(
                self.bootstrap_port[i] if self.bootstrap_port is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            data_parallel_rank=self.data_parallel_rank,
        )
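
    # Illustrative (hypothetical values): after normalization, a batch request
    # can be split into per-request views by index.
    #
    #   batch = GenerateReqInput(text=["a", "b"])
    #   batch.normalize_batch_and_arguments()
    #   first = batch[0]  # a GenerateReqInput with text="a"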


@dataclass
class TokenizedGenerateReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token ids to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None


@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0
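
    # Illustrative sketch (hypothetical values):
    #
    #   req = EmbeddingReqInput(text=["hello", "world"])
    #   req.normalize_batch_and_arguments()
    #   # req.batch_size == 2, req.is_single is False, and every
    #   # sampling_params entry carries max_new_tokens=0.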

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def __getitem__(self, i):
        if self.is_cross_encoder_request:
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                is_cross_encoder_request=True,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            audio_data=self.audio_data[i] if self.audio_data is not None else None,
            video_data=self.video_data[i] if self.video_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )


@dataclass
class TokenizedEmbeddingReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The token type ids
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


@dataclass
class BatchTokenIDOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalDecodeReq:
    # The request id
    rids: List[str]
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchStrOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchEmbeddingOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class FlushCacheReqInput:
    pass


@dataclass
class FlushCacheReqOutput:
    success: bool


@dataclass
class UpdateWeightFromDiskReqInput:
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False


@dataclass
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
class UpdateWeightsFromDistributedReqInput:
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    # The group name
    group_name: str = "weight_update_group"
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    """Update model weights from tensor input.

    - Tensors are serialized for transmission
    - Data is structured in JSON for easy transmission over HTTP
    """

    serialized_named_tensors: List[Union[str, bytes]]
    # Optional format specification for loading
    load_format: Optional[str] = None
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
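
    # Illustrative only (hypothetical payload variable): the tensors are
    # serialized by the sender and passed through opaquely here.
    #
    #   UpdateWeightsFromTensorReqInput(serialized_named_tensors=[payload_bytes])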


@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    success: bool
    message: str


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list


@dataclass
class ReleaseMemoryOccupationReqInput:
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None
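
    # Illustrative: a request that releases only the KV cache region, using
    # one of the tags named above.
    #
    #   ReleaseMemoryOccupationReqInput(tags=["kv_cache"])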


@dataclass
class ReleaseMemoryOccupationReqOutput:
    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ResumeMemoryOccupationReqOutput:
    pass


@dataclass
class SlowDownReqInput:
    forward_sleep_time: Optional[float]


@dataclass
class SlowDownReqOutput:
    pass


@dataclass
class AbortReq:
    # The request id
    rid: str = ""
    # Whether to abort all requests
    abort_all: bool = False
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None


@dataclass
class GetInternalStateReq:
    pass


@dataclass
class GetInternalStateReqOutput:
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    # The output directory
    output_dir: Optional[str] = None
    # If set, profile at most this number of steps.
    # Profiling then stops automatically, so the caller does not need to call
    # stop_profile.
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
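
    # Illustrative (hypothetical values): profile at most 5 steps, stopping
    # automatically without a stop_profile call.
    #
    #   ProfileReqInput(output_dir="/tmp/traces", num_steps=5)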


class ProfileReqType(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class ProfileReq:
    type: ProfileReqType
    output_dir: Optional[str] = None
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    profile_id: Optional[str] = None


@dataclass
class ProfileReqOutput:
    success: bool
    message: str


@dataclass
class ConfigureLoggingReq:
    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    session_id: str


@dataclass
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool


@dataclass
class HealthCheckOutput:
    pass


class ExpertDistributionReq(Enum):
    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


@dataclass
class ExpertDistributionReqOutput:
    pass


@dataclass
class Function:
    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
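
    # Illustrative sketch with a hypothetical tool definition:
    #
    #   req = ParseFunctionCallReq(
    #       text='{"name": "get_weather", "arguments": {"city": "Paris"}}',
    #       tools=[Tool(function=Function(name="get_weather"))],
    #       tool_call_parser="llama3",
    #   )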


@dataclass
class SeparateReasoningReqInput:
    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


@dataclass
class VertexGenerateReqInput:
    instances: List[dict]
    parameters: Optional[dict] = None


@dataclass
class RpcReqInput:
    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    success: bool
    message: str


@dataclass
class LoadLoRAAdapterReqInput:
    # The name of the LoRA adapter to be newly loaded.
    lora_name: str
    # The path to load the LoRA adapter from.
    lora_path: str
    # The unique identifier for the LoRA adapter, which is automatically generated by the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
            lora_path=self.lora_path,
        )
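
    # Illustrative (hypothetical names): converting a load request into the
    # registry's LoRARef.
    #
    #   req = LoadLoRAAdapterReqInput(lora_name="my-adapter", lora_path="/adapters/my-adapter")
    #   ref = req.to_ref()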


@dataclass
class UnloadLoRAAdapterReqInput:
    # The name of the LoRA adapter to unload.
    lora_name: str
    # The unique identifier for the LoRA adapter, which is automatically generated by the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
        )


@dataclass
class LoRAUpdateResult:
    success: bool
    error_message: Optional[str] = None
    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult


class BlockReqType(Enum):
    BLOCK = 1
    UNBLOCK = 2


@dataclass
class BlockReqInput:
    type: BlockReqType