# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams

# Handle serialization of Image for pydantic
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any


@dataclass
class SessionParams:
    id: Optional[str] = None
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None


AudioDataItem = Union[str, Dict]
ImageDataItem = Union[Image, str, Dict]


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[
        Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
    ] = None
    # The audio input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If returning logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is -1, which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If returning logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If returning logprobs, the token ids to return logprobs for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens into text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # The path to the LoRA
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # Whether to return hidden states
    return_hidden_states: Union[List[bool], bool] = False

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    def contains_mm_input(self) -> bool:
        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()

        self._validate_session_params()

    def _validate_inputs(self):
        """Validate that the input configuration is valid."""
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids, or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size."""
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed."""
        # Determine parallel sample count
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )

        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]

    def _normalize_single_inputs(self):
        """Normalize inputs for a single example."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Calculate expanded batch size
        if self.parallel_sample_num == 1:
            num = self.batch_size
        else:
            # Expand parallel_sample_num
            num = self.batch_size * self.parallel_sample_num

        # Expand input based on type
        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Normalize LoRA paths for batch processing."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data for batch processing."""
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image, convert to list of single-image lists
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        elif isinstance(self.image_data, list):
            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is
                for i in range(len(self.image_data)):
                    if self.image_data[i] is None or self.image_data[i] == [None]:
                        self.modalities.append(None)
                    elif len(self.image_data[i]) == 1:
                        self.modalities.append("image")
                    elif len(self.image_data[i]) > 1:
                        self.modalities.append("multi-images")
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num
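
    # Example (hypothetical values) of the shapes handled above: for a batch of
    # two requests where only the first has images,
    #     image_data=[["a.png", "b.png"], None]
    # yields modalities=["multi-images", None] before parallel-sample expansion.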

    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        elif isinstance(self.audio_data, list):
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters for batch processing."""
        if self.sampling_params is None:
            self.sampling_params = [{}] * num
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            new_rids = [f"{self.rid}_{i}" for i in range(num)]
            self.rid = new_rids
        elif isinstance(self.rid, list):
            if len(self.rid) != num:
                raise ValueError(
                    "The number of specified rids does not match the batch size for batch processing."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing."""

        # Helper function to normalize a parameter
        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        # Normalize each logprob parameter
        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing."""
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _validate_session_params(self):
        """Validate that session parameters are properly formatted."""
        if self.session_params is not None:
            if not isinstance(self.session_params, dict) and not isinstance(
                self.session_params[0], dict
            ):
                raise ValueError("Session params must be a dict or a list of dicts.")

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            # If `__getitem__` is called, bootstrap_host, bootstrap_port, and bootstrap_room must be lists.
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_port=(
                self.bootstrap_port[i] if self.bootstrap_port is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            data_parallel_rank=self.data_parallel_rank,
        )
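
# Example (illustrative sketch, hypothetical values): two prompts with n=2
# parallel samples expand into four per-request entries after normalization,
# and `__getitem__` then yields a per-request view.
#
#     req = GenerateReqInput(text=["Hi", "Bye"], sampling_params={"n": 2})
#     req.normalize_batch_and_arguments()
#     assert req.batch_size == 2 and len(req.rid) == 4
#     sub_req = req[2]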


@dataclass
class TokenizedGenerateReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If returning logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If returning logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If returning logprobs, the token ids to return logprobs for.
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None


@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[
        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
    ] = None
    # The audio input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    audio_data: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # Check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # Check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                # Use a comprehension so each request gets its own dict rather
                # than batch_size references to one shared dict.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def contains_mm_input(self) -> bool:
        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)

    def __getitem__(self, i):
        if self.is_cross_encoder_request:
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                input_ids=None,
                image_data=None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                is_cross_encoder_request=True,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
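
# Example (illustrative sketch): a single embedding request; normalization fills
# in an rid and forces max_new_tokens to 0 in the dummy sampling params.
#
#     req = EmbeddingReqInput(text="What is the capital of France?")
#     req.normalize_batch_and_arguments()
#     assert req.is_single and req.sampling_params["max_new_tokens"] == 0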


@dataclass
class TokenizedEmbeddingReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The token type ids
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


@dataclass
class BatchTokenIDOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalDecodeReq:
    # The request id
    rids: List[str]
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchStrOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchEmbeddingOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class FlushCacheReqInput:
    pass


@dataclass
class FlushCacheReqOutput:
    success: bool


@dataclass
class UpdateWeightFromDiskReqInput:
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False


@dataclass
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
class UpdateWeightsFromDistributedReqInput:
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    # The group name
    group_name: str = "weight_update_group"
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    """Update model weights from tensor input.

    - Tensors are serialized for transmission
    - Data is structured in JSON for easy transmission over HTTP
    """

    serialized_named_tensors: List[Union[str, bytes]]
    # Optional format specification for loading
    load_format: Optional[str] = None
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
787
788
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
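
# Example (illustrative sketch; the wire format of the serialized tensors is
# chosen by the caller and its matching deserializer, not by this dataclass):
# packing a named tensor into bytes with torch.save before building the request.
#
#     import io, torch
#     buf = io.BytesIO()
#     torch.save([("lm_head.weight", torch.zeros(8, 8))], buf)
#     req = UpdateWeightsFromTensorReqInput(
#         serialized_named_tensors=[buf.getvalue()], flush_cache=True
#     )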


@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"
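
# Example (hypothetical addresses and sizes): joining a process group that a
# trainer created for broadcasting updated weights.
#
#     req = InitWeightsUpdateGroupReqInput(
#         master_address="10.0.0.1", master_port=29500,
#         rank_offset=1, world_size=2,
#     )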


@dataclass
class InitWeightsUpdateGroupReqOutput:
    success: bool
    message: str


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list


@dataclass
class ReleaseMemoryOccupationReqInput:
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None
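
# Example: release only the KV cache while keeping the weights resident.
#
#     req = ReleaseMemoryOccupationReqInput(tags=["kv_cache"])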


@dataclass
class ReleaseMemoryOccupationReqOutput:
    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ResumeMemoryOccupationReqOutput:
    pass


@dataclass
class SlowDownReqInput:
    forward_sleep_time: Optional[float]


@dataclass
class SlowDownReqOutput:
    pass


@dataclass
class AbortReq:
    # The request id
    rid: str = ""
    # Whether to abort all requests
    abort_all: bool = False


@dataclass
class GetInternalStateReq:
    pass


@dataclass
class GetInternalStateReqOutput:
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    # The output directory
    output_dir: Optional[str] = None
    # If set, profile at most this number of steps; profiling stops automatically
    # after that step, and the caller does not need to run stop_profile.
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
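
# Example (hypothetical values): profile the next 5 steps and stop automatically,
# so the caller never needs to issue a stop_profile request.
#
#     req = ProfileReqInput(output_dir="/tmp/sglang_trace", num_steps=5)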


class ProfileReqType(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


class ExpertDistributionReq(Enum):
    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


@dataclass
class ExpertDistributionReqOutput:
    pass


@dataclass
class ProfileReq:
    type: ProfileReqType
    output_dir: Optional[str] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    profile_id: Optional[str] = None


@dataclass
class ProfileReqOutput:
    success: bool
    message: str


@dataclass
class ConfigureLoggingReq:
    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    session_id: str


@dataclass
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool


@dataclass
class HealthCheckOutput:
    pass


@dataclass
class Function:
    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
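
# Example (hypothetical tool and text): parsing a model response for calls to a
# single `get_weather` function with the qwen25 parser.
#
#     req = ParseFunctionCallReq(
#         text='{"name": "get_weather", "arguments": {"city": "Paris"}}',
#         tools=[Tool(function=Function(name="get_weather", parameters={}))],
#         tool_call_parser="qwen25",
#     )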


@dataclass
class SeparateReasoningReqInput:
    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


@dataclass
class VertexGenerateReqInput:
    instances: List[dict]
    parameters: Optional[dict] = None


@dataclass
class RpcReqInput:
    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    success: bool
    message: str


@dataclass
class LoadLoRAAdapterReqInput:
    # The name of the LoRA module to be newly loaded.
    lora_name: str
    # The path to load the LoRA adapter from.
    lora_path: str


@dataclass
class UnloadLoRAAdapterReqInput:
    # The name of the LoRA module to unload.
    lora_name: str


@dataclass
class LoRAUpdateResult:
    success: bool
    error_message: Optional[str] = None
    loaded_adapters: Dict[str, str] = field(default_factory=dict)


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult