io_struct.py 34.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

19
import copy
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import uuid
YAMY's avatar
YAMY committed
21
from dataclasses import dataclass, field
22
from enum import Enum
23
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
24

25
from sglang.srt.managers.schedule_batch import BaseFinishReason
26
from sglang.srt.multimodal.mm_utils import has_valid_data
27
from sglang.srt.sampling.sampling_params import SamplingParams
28

29
# Handle serialization of Image for pydantic: static type checkers see the real
# PIL ``Image`` class, while at runtime ``Image`` is simply ``typing.Any``.
# PIL is therefore not imported at runtime, and annotations that mention
# ``Image`` behave like ``Any``.
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any
34

Lianmin Zheng's avatar
Lianmin Zheng committed
35

36
37
38
39
40
41
@dataclass
class SessionParams:
    """Session information for continual prompting.

    Every field is optional and defaults to None. How the fields are
    interpreted is decided by the session-handling code, which is not part of
    this module.
    """

    # Session id.
    id: Optional[str] = None
    # Request id within the session.
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
    drop_previous_output: Optional[bool] = None
43
44


45
46
47
48
# One audio input: a string (file name, URL, or base64 encoded data) or a dict.
# NOTE(review): the expected dict schema is not defined in this module — TODO confirm.
AudioDataItem = Union[str, Dict]
# One image input: an image instance, a string (file name, URL, or base64
# encoded data), or a dict. ``Image`` is only ``typing.Any`` at runtime (see the
# TYPE_CHECKING block above). The member order is kept as-is because validators
# such as pydantic may try the members in order.
ImageDataItem = Union[Image, str, Dict]


Lianmin Zheng's avatar
Lianmin Zheng committed
49
50
@dataclass
class GenerateReqInput:
    """A generation request: either a single prompt or a batch of prompts.

    Most fields accept either one value or a per-request list. Call
    ``normalize_batch_and_arguments()`` to resolve the batch shape and fill in
    defaults. After a batch has been normalized, ``req[i]`` returns the i-th
    sub-request as its own ``GenerateReqInput``.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[
        Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
    ] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
    # The sampling params: one dict, or one dict per request. The "n" entry
    # selects parallel sampling (see _handle_parallel_sampling).
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # The path to the LoRA
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # Whether to return hidden states
    return_hidden_states: Union[List[bool], bool] = False

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    def contains_mm_input(self) -> bool:
        """Return whether image_data or audio_data holds data accepted by has_valid_data."""
        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()

        self._validate_session_params()

    def _validate_inputs(self):
        """Validate that the input configuration is valid."""
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Set ``is_single`` and ``batch_size`` from whichever input is present.

        The precedence is text, then input_ids, then input_embeds. When text or
        input_ids wins, input_embeds is cleared.

        Raises:
            ValueError: if the selected input_ids / input_embeds list is empty.
        """
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            # Mirror the input_ids check instead of failing with an IndexError below.
            if len(self.input_embeds) == 0:
                raise ValueError("input_embeds cannot be empty.")
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed.

        Reads ``n`` from the sampling params into ``parallel_sample_num``. When
        ``n > 1`` and the request is single, it is promoted to a batch of one.

        Raises:
            ValueError: if a list of sampling params disagrees on ``n``.
        """
        # Determine parallel sample count
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )

        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
            if self.input_embeds is not None:
                self.input_embeds = [self.input_embeds]
            # A single request may carry a flat list of images (several images
            # for this one prompt). Wrap that list so it lines up with the new
            # batch of one; otherwise _normalize_image_data rejects every such
            # list whose length is not 1. Lists of length 1 already normalize
            # correctly, so they are left untouched.
            if (
                isinstance(self.image_data, list)
                and len(self.image_data) != 1
                and not (self.image_data and isinstance(self.image_data[0], list))
            ):
                self.image_data = [self.image_data]

    def _normalize_single_inputs(self):
        """Normalize inputs for a single example."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Calculate expanded batch size
        if self.parallel_sample_num == 1:
            num = self.batch_size
        else:
            # Expand parallel_sample_num
            num = self.batch_size * self.parallel_sample_num

        # Expand input based on type
        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Normalize LoRA paths for batch processing."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data to one entry per (expanded) request and derive ``modalities``.

        Raises:
            ValueError: if a list of image data does not match ``batch_size``.
        """
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image: give every request its own one-image list, so the
            # requests do not share (and cannot mutate) one list object.
            single_image = self.image_data
            self.image_data = [[single_image] for _ in range(num)]
            self.modalities = ["image"] * num
        elif isinstance(self.image_data, list):
            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is. Exactly one modality is
                # appended per request so modalities[i] lines up with image_data[i];
                # an empty per-request list counts as "no image".
                for images in self.image_data:
                    if images is None or images == [None] or len(images) == 0:
                        self.modalities.append(None)
                    elif len(images) == 1:
                        self.modalities.append("image")
                    else:
                        self.modalities.append("multi-images")
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        elif isinstance(self.audio_data, list):
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters for batch processing."""
        if self.sampling_params is None:
            self.sampling_params = [{}] * num
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            new_rids = [f"{self.rid}_{i}" for i in range(num)]
            self.rid = new_rids
        elif isinstance(self.rid, list):
            # Note: the length of rid shall be the same as the batch_size,
            # as the rid would be expanded for parallel sampling in tokenizer_manager
            if len(self.rid) != self.batch_size:
                raise ValueError(
                    "The specified rids length mismatch with the batch_size for batch processing."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing."""

        # Helper function to normalize a parameter
        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        # Normalize each logprob parameter
        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing."""
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _validate_session_params(self):
        """Validate that session parameters are a dict or a list of dicts.

        Only the first element of a list is checked. An empty list is accepted
        instead of failing with an IndexError.
        """
        params = self.session_params
        if params is None or isinstance(params, dict):
            return
        if isinstance(params, list) and (not params or isinstance(params[0], dict)):
            return
        raise ValueError("Session params must be a dict or a list of dicts.")

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    @staticmethod
    def _select_item(value, i):
        """Return ``value[i]`` for a per-request list; otherwise return ``value`` unchanged."""
        return value[i] if isinstance(value, list) else value

    def __getitem__(self, i):
        """Return the i-th request of a batch as a standalone GenerateReqInput.

        This expects ``normalize_batch_and_arguments()`` to have run on a batch,
        so that per-request fields (sampling_params, rid, logprob params, ...)
        are lists.
        """
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            input_embeds=(
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=self._select_item(self.return_hidden_states, i),
            # bootstrap_host / bootstrap_port are never expanded during batch
            # normalization. A scalar value is shared by all requests; indexing
            # it would yield one character of the host string, or raise for an
            # int port.
            bootstrap_host=self._select_item(self.bootstrap_host, i),
            bootstrap_port=self._select_item(self.bootstrap_port, i),
            # bootstrap_room must be a per-request list when set on a batch.
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            data_parallel_rank=self.data_parallel_rank,
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
448
449
450

@dataclass
class TokenizedGenerateReqInput:
    """A single generate request after tokenization.

    The fields mirror one element of a normalized ``GenerateReqInput``, with
    the text already converted to ``input_ids`` and the sampling params held as
    a ``SamplingParams`` object.
    """

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token id to return logprob for.
    # NOTE(review): GenerateReqInput normalizes an empty value to None, so this
    # may be None in practice — confirm against the tokenizer manager.
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

Lianmin Zheng's avatar
Lianmin Zheng committed
496

497
498
499
@dataclass
class EmbeddingReqInput:
    """An embedding (or cross-encoder) request: a single input or a batch.

    Call ``normalize_batch_and_arguments()`` to resolve the batch shape and
    defaults. After a batch has been normalized, ``req[i]`` returns the i-th
    sub-request.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[
        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
    ] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False

    def normalize_batch_and_arguments(self):
        """Derive ``is_single`` / ``batch_size`` and fill in rid and sampling params.

        ``max_new_tokens`` is forced to 0 in every sampling-params dict. A single
        dict supplied for a batch is copied once per request.

        Raises:
            ValueError: if none of text/input_ids/image_data is given, if both
                text and input_ids are given, or if input_ids is empty.
        """
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            # Same check as GenerateReqInput, instead of an IndexError below.
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                # One distinct dict per request (no shared aliasing).
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif isinstance(self.sampling_params, dict):
                # A single dict for the whole batch: copy it per request rather
                # than indexing a dict with integers (KeyError).
                shared = self.sampling_params
                self.sampling_params = [dict(shared) for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        """Generate a new request ID, store it on the request, and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def contains_mm_input(self) -> bool:
        """Return whether image_data or audio_data holds data accepted by has_valid_data."""
        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)

    def __getitem__(self, i):
        """Return the i-th request of a normalized batch as a standalone request.

        ``log_metrics`` is propagated so that a batch with ``log_metrics=False``
        does not log metrics for its sub-requests.
        """
        if self.is_cross_encoder_request:
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                input_ids=None,
                image_data=None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                log_metrics=self.log_metrics,
                is_cross_encoder_request=True,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            log_metrics=self.log_metrics,
        )
601
602
603


@dataclass
class TokenizedEmbeddingReqInput:
    """A single embedding request after tokenization."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The token type ids
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


Lianmin Zheng's avatar
Lianmin Zheng committed
619
620
@dataclass
class BatchTokenIDOut:
    """Token-level outputs for a batch of requests.

    NOTE(review): the list fields presumably hold one entry per id in ``rids``
    — confirm against the producer.
    """

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

@dataclass
class BatchMultimodalDecodeReq:
    """Batched multimodal decode request: request ids, finish reasons and token counts."""

    # The request ids
    rids: List[str]
    # The finish reason of each request
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchStrOut:
    """Batched decoded-string outputs for a group of requests.

    Every field is a list; presumably entry ``i`` of each list belongs to
    request ``rids[i]`` -- TODO confirm against the producer of this object.
    """

    # The request ids
    rids: List[str]
    # The finish reason of each request
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs; each `_val` list appears to pair with the matching `_idx` list
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

@dataclass
class BatchMultimodalOut:
    """Batched multimodal outputs for a group of requests."""

    # The request ids
    rids: List[str]
    # The finish reason of each request
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchEmbeddingOut:
    """Batched embedding outputs for a group of requests."""

    # The request ids
    rids: List[str]
    # The finish reason of each request
    finished_reasons: List[BaseFinishReason]
    # The output embedding of each request
    embeddings: List[List[float]]

    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class FlushCacheReqInput:
    """Field-less request message asking for a cache flush."""

@dataclass
class FlushCacheReqOutput:
    """Output counterpart of ``FlushCacheReqInput``."""

    # Whether the flush succeeded
    success: bool


@dataclass
class UpdateWeightFromDiskReqInput:
    """Request to update model weights from a path on disk."""

    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False


@dataclass
class UpdateWeightFromDiskReqOutput:
    """Output counterpart of ``UpdateWeightFromDiskReqInput``."""

    # Whether the update succeeded, plus an accompanying message
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request to update named weights through a distributed group."""

    # Parameter names, dtypes and shapes; presumably parallel lists with one
    # entry per tensor -- TODO confirm with the sender.
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    # The group name
    group_name: str = "weight_update_group"
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Output counterpart of ``UpdateWeightsFromDistributedReqInput``."""

    # Outcome flag and accompanying status message
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    """Update model weights from tensor input.

    - Tensors are serialized for transmission
    - Data is structured in JSON for easy transmission over HTTP
    """

    # The serialized named tensors (each entry is a str or bytes payload)
    serialized_named_tensors: List[Union[str, bytes]]
    # Optional format specification for loading
    load_format: Optional[str] = None
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False


@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Output counterpart of ``UpdateWeightsFromTensorReqInput``."""

    # Outcome flag and accompanying status message
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request to initialize a weights-update group."""

    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Output counterpart of ``InitWeightsUpdateGroupReqInput``."""

    # Outcome flag and accompanying status message
    success: bool
    message: str


@dataclass
class GetWeightsByNameReqInput:
    """Request to fetch a weight by its name."""

    # The weight name
    name: str
    # Truncation size; presumably limits how much of the parameter is
    # returned -- TODO confirm with the handler.
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    """Output counterpart of ``GetWeightsByNameReqInput``."""

    # The returned parameter data
    parameter: list


@dataclass
class ReleaseMemoryOccupationReqInput:
    """Request to release memory occupation."""

    # Optional tags to identify the memory region, primarily used for RL.
    # Currently only `weights` and `kv_cache` are supported.
    tags: Optional[List[str]] = None


@dataclass
class ReleaseMemoryOccupationReqOutput:
    """Field-less output counterpart of ``ReleaseMemoryOccupationReqInput``."""


@dataclass
class ResumeMemoryOccupationReqInput:
    """Request to resume memory occupation."""

    # Optional tags to identify the memory region, primarily used for RL.
    # Currently only `weights` and `kv_cache` are supported.
    tags: Optional[List[str]] = None


@dataclass
class ResumeMemoryOccupationReqOutput:
    """Field-less output counterpart of ``ResumeMemoryOccupationReqInput``."""


@dataclass
class SlowDownReqInput:
    """Request carrying a forward sleep time (required, may be None)."""

    # Sleep time applied to the forward pass; the unit is not stated here
    # (presumably seconds -- TODO confirm with the handler).
    forward_sleep_time: Optional[float]


@dataclass
class SlowDownReqOutput:
    """Field-less output counterpart of ``SlowDownReqInput``."""


@dataclass
class AbortReq:
    """Request to abort one request (by id) or all requests."""

    # The request id
    rid: str = ""
    # Whether to abort all requests
    abort_all: bool = False


@dataclass
class GetInternalStateReq:
    """Field-less request asking for the internal state."""


@dataclass
class GetInternalStateReqOutput:
    """Output counterpart of ``GetInternalStateReq``."""

    # Arbitrary key/value snapshot of the internal state
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    """Request carrying server-argument overrides keyed by argument name."""

    # Server-argument name -> new value
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    """Output counterpart of ``SetInternalStateReq``."""

    # Whether anything was updated, plus the server arguments reported back
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    """Request to start profiling; every field is optional."""

    # The output directory
    output_dir: Optional[str] = None
    # NOTE(review): not documented here; presumably the step at which
    # profiling begins -- confirm against the profiler.
    start_step: Optional[int] = None
    # If set, profile this many steps. Profiling then stops automatically
    # after the last one, so the caller doesn't need to run stop_profile.
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None


class ProfileReqType(Enum):
    """Kind of profiling request: start or stop."""

    START_PROFILE = 1
    STOP_PROFILE = 2


class ExpertDistributionReq(Enum):
    """Commands for recording expert distribution: start, stop, dump."""

    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


@dataclass
class ExpertDistributionReqOutput:
    """Field-less output counterpart of ``ExpertDistributionReq``."""


@dataclass
class ProfileReq:
    """Profiling command: a ``ProfileReqType`` plus optional settings.

    The optional fields mirror those of ``ProfileReqInput``, with an extra
    ``profile_id``.
    """

    type: ProfileReqType
    output_dir: Optional[str] = None
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    profile_id: Optional[str] = None


@dataclass
class ProfileReqOutput:
    """Output counterpart of a profiling request."""

    # Outcome flag and accompanying status message
    success: bool
    message: str


@dataclass
class ConfigureLoggingReq:
    """Request to reconfigure request logging/dumping; ``None`` fields are unset."""

    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None


@dataclass
class OpenSessionReqInput:
    """Request to open a session."""

    # Session capacity, expressed as a string length
    capacity_of_str_len: int
    # Optional session id
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    """Request to close the session identified by ``session_id``."""

    # Identifier of the session to close
    session_id: str


@dataclass
class OpenSessionReqOutput:
    """Output counterpart of ``OpenSessionReqInput``."""

    # The session id (may be None) and whether opening succeeded
    session_id: Optional[str]
    success: bool


@dataclass
class HealthCheckOutput:
    """Field-less health-check output message."""


@dataclass
class Function:
    """Description of a callable function tool; every field is optional."""

    description: Optional[str] = None
    name: Optional[str] = None
    # Parameter specification; its schema is not fixed here (typed as object)
    parameters: Optional[object] = None


@dataclass
class Tool:
    """A tool wrapping a ``Function``; ``type`` defaults to ``"function"``."""

    # The wrapped function description
    function: Function
    # Tool kind tag
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    """Request to parse function calls out of a piece of text."""

    # The text to parse.
    text: str
    # A list of available function tools (name, parameters, etc.).
    tools: List[Tool] = field(default_factory=list)
    # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'.
    # If not specified, tries all.
    tool_call_parser: Optional[str] = None


@dataclass
class SeparateReasoningReqInput:
    """Request to separate reasoning content from a piece of text."""

    # The text to parse.
    text: str
    # Specify the parser type, e.g., "deepseek-r1".
    reasoning_parser: str


@dataclass
class VertexGenerateReqInput:
    """Generation request with ``instances`` and optional ``parameters``.

    The field names presumably mirror a Vertex AI-style prediction payload --
    TODO confirm with the HTTP handler.
    """

    instances: List[dict]
    parameters: Optional[dict] = None


@dataclass
class RpcReqInput:
    """Generic RPC request: a method name plus optional parameters."""

    # Name of the method to invoke
    method: str
    # Optional keyword-style parameters for the method
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    """Output counterpart of ``RpcReqInput``."""

    # Outcome flag and accompanying status message
    success: bool
    message: str


@dataclass
class LoadLoRAAdapterReqInput:
    """Request to load a new LoRA adapter under a given name."""

    # The name under which the LoRA module is to be loaded.
    lora_name: str
    # The path to load the adapter from.
    lora_path: str


@dataclass
class UnloadLoRAAdapterReqInput:
    """Request to unload the LoRA adapter registered under ``lora_name``."""

    # The name of the LoRA module to unload.
    lora_name: str


@dataclass
class LoRAUpdateResult:
    """Result of a LoRA load or unload operation."""

    # Whether the operation succeeded
    success: bool
    # Error description, if any
    error_message: Optional[str] = None
    # Adapter name -> adapter path mapping; each instance gets its own dict
    loaded_adapters: Dict[str, str] = field(default_factory=dict)


# Load and unload report their outcome with the same result type.
LoadLoRAAdapterReqOutput = LoRAUpdateResult
UnloadLoRAAdapterReqOutput = LoRAUpdateResult