io_struct.py 33.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

19
import copy
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import uuid
YAMY's avatar
YAMY committed
21
from dataclasses import dataclass, field
22
from enum import Enum
23
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
24

25
26
from sglang.srt.mm_utils import has_valid_data

27
28
29
30
31
# handle serialization of Image for pydantic
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any
32

33
from sglang.srt.managers.schedule_batch import BaseFinishReason
34
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
35
36


37
38
39
40
41
42
43
44
@dataclass
class SessionParams:
    """Session information used for continual prompting.

    Every field is optional and defaults to ``None``.
    """

    # Session identifier.
    id: Optional[str] = field(default=None)
    # Request id associated with the session (exact role defined by the
    # session handling code elsewhere -- TODO confirm).
    rid: Optional[str] = field(default=None)
    # Offset within the session (presumably a token/position offset -- verify).
    offset: Optional[int] = field(default=None)
    # Replace flag for the session content (semantics not shown here -- verify).
    replace: Optional[bool] = field(default=None)


45
46
47
48
# One audio input item: a string (file name, URL or base64 data, per the
# audio_data field comments) or a dict whose schema is not defined in this file.
AudioDataItem = Union[str, Dict]
# One image input item: an image object, a string, or a dict. `Image` is PIL's
# Image class for type checkers and `Any` at runtime (TYPE_CHECKING guard above).
ImageDataItem = Union[Image, str, Dict]


Lianmin Zheng's avatar
Lianmin Zheng committed
49
50
@dataclass
class GenerateReqInput:
    """A generation request as received from the user-facing API.

    One instance describes either a single request or a batch of requests.
    ``normalize_batch_and_arguments`` resolves which, broadcasts per-request
    fields to lists for batches, and fills in defaults. ``__getitem__`` then
    extracts the i-th request of a normalized batch.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[
        Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
    ] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
    # The sampling parameters: one dict shared by all requests, or one dict per request.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # The path to the LoRA
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # Whether to return hidden states
    return_hidden_states: Union[List[bool], bool] = False

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    def contains_mm_input(self) -> bool:
        """Return True if the request carries valid image or audio data."""
        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()

        self._validate_session_params()

    def _validate_inputs(self):
        """Validate that the input configuration is valid.

        Rejects requests that provide none of text / input_ids / input_embeds,
        or all three at once.

        Raises:
            ValueError: On either of the conditions above.
        """
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size.

        Sets ``self.is_single`` and ``self.batch_size`` from the first non-None
        of text, input_ids, input_embeds. When text or input_ids is present,
        input_embeds is discarded.

        Raises:
            ValueError: If input_ids or input_embeds is empty.
        """
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            # Guard the [0][0] probe below so empty input gives a clear error
            # instead of a bare IndexError.
            if len(self.input_embeds) == 0 or len(self.input_embeds[0]) == 0:
                raise ValueError("input_embeds cannot be empty.")
            # A single request is a list of per-token vectors, so [0][0] is a
            # number. Integer-valued entries (e.g. 0 decoded from JSON) count as
            # numbers too; otherwise they would be misread as a batch.
            if isinstance(self.input_embeds[0][0], (int, float)):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed.

        Reads ``n`` from the sampling params into ``self.parallel_sample_num``;
        every dict in a list of sampling params must agree on it. A single
        request with n > 1 is rewrapped as a one-element batch so that
        ``_expand_inputs`` can repeat it.

        Raises:
            ValueError: If the sampling params in a list disagree on ``n``.
        """
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )

        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
            # input_embeds must be rewrapped as well. Otherwise _expand_inputs
            # would repeat the per-token vectors rather than the whole request.
            if self.input_embeds is not None:
                self.input_embeds = [self.input_embeds]

    def _normalize_single_inputs(self):
        """Fill in defaults for a single (non-batched) request."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Each original request is repeated parallel_sample_num times.
        num = self.batch_size * self.parallel_sample_num

        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)

    def _expand_inputs(self, num):
        """Repeat the main input (text, input_ids or input_embeds) parallel_sample_num times.

        Raises:
            ValueError: If the input is not in batch (list) form.
        """
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Broadcast a single LoRA path, or repeat a per-request list, to ``num`` entries."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data to one list of images per request and set ``modalities``.

        Raises:
            ValueError: If a list of image data does not match the batch size.
        """
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image, convert to list of single-image lists
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        else:
            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is
                for images in self.image_data:
                    if not images or images == [None]:
                        # No image for this request (None, empty, or [None]).
                        # An empty list must still add an entry so modalities
                        # stays aligned with image_data.
                        self.modalities.append(None)
                    elif len(images) == 1:
                        self.modalities.append("image")
                    else:
                        self.modalities.append("multi-images")
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

    def _normalize_audio_data(self, num):
        """Broadcast or repeat audio data to ``num`` entries."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        else:
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Broadcast or repeat sampling params to ``num`` entries."""
        if self.sampling_params is None:
            self.sampling_params = [{}] * num
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Produce ``num`` request ids: generated, suffixed from a base id, or validated.

        Raises:
            ValueError: If a rid list has the wrong length or rid has an invalid type.
        """
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            self.rid = [f"{self.rid}_{i}" for i in range(num)]
        elif isinstance(self.rid, list):
            if len(self.rid) != num:
                raise ValueError(
                    "The specified rids length mismatch with the batch_size for batch processing."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters to ``num`` entries each.

        Raises:
            ValueError: If a per-request list is combined with parallel sampling.
        """

        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                # A per-request list cannot be lined up with the repeated requests.
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Broadcast the custom logit processor to ``num`` entries.

        Raises:
            ValueError: If a per-request list is combined with parallel sampling.
        """
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _validate_session_params(self):
        """Validate that session parameters are a dict or a non-empty list of dicts.

        Raises:
            ValueError: If session_params has an invalid form.
        """
        if self.session_params is None or isinstance(self.session_params, dict):
            return
        # Only the first element is type-checked. The emptiness guard turns an
        # IndexError on [] into the proper validation error.
        if not self.session_params or not isinstance(self.session_params[0], dict):
            raise ValueError("Session params must be a dict or a list of dicts.")

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return the i-th request of a normalized batch as its own GenerateReqInput.

        Only valid after ``normalize_batch_and_arguments`` has produced a batch.
        Per-request fields are indexed; batch-wide flags are copied as-is.
        """
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            # Previously omitted, which dropped the embeddings of batched
            # input_embeds requests.
            input_embeds=(
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_port=(
                self.bootstrap_port[i] if self.bootstrap_port is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            data_parallel_rank=self.data_parallel_rank,
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
440
441
442

@dataclass
class TokenizedGenerateReqInput:
    """A single generation request after tokenization.

    Unlike GenerateReqInput, every field holds the value for exactly one request.
    """

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token id to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

Lianmin Zheng's avatar
Lianmin Zheng committed
488

489
490
491
@dataclass
class EmbeddingReqInput:
    """An embedding request from the user-facing API.

    One instance describes a single request or a batch. Call
    ``normalize_batch_and_arguments`` before indexing it. Cross-encoder requests
    are flagged with ``is_cross_encoder_request``.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[
        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
    ] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility (None is the default, hence Optional)
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False

    def contains_mm_input(self) -> bool:
        """Return True if the request carries valid image or audio data."""
        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)

    def normalize_batch_and_arguments(self):
        """Derive ``is_single`` / ``batch_size`` and fill in default arguments.

        Every request's sampling params get ``max_new_tokens = 0`` so that no
        tokens are generated.

        Raises:
            ValueError: If none of text / input_ids / image_data is given, if
                both text and input_ids are given, if input_ids is empty, or if
                a batch is given a non-list rid.
        """
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            # Same guard as GenerateReqInput; [0] below would raise IndexError.
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            elif not isinstance(self.rid, list):
                # Raise rather than assert so the check survives `python -O`.
                raise ValueError("The rid should be a list.")

            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif isinstance(self.sampling_params, dict):
                # Broadcast one dict to the whole batch. Positional indexing
                # below would otherwise raise KeyError on a dict.
                self.sampling_params = [self.sampling_params] * self.batch_size
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return the i-th request of a normalized batch.

        For cross-encoder requests, the i-th text is kept wrapped in a list and
        image inputs are dropped.
        """
        if self.is_cross_encoder_request:
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                input_ids=None,
                image_data=None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                is_cross_encoder_request=True,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
593
594
595


@dataclass
class TokenizedEmbeddingReqInput:
    """A single embedding request after tokenization."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The token type ids
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


Lianmin Zheng's avatar
Lianmin Zheng committed
611
612
@dataclass
class BatchTokenIDOut:
    """A batch of token-level outputs, with the state needed to detokenize them.

    Each list field holds one entry per request in ``rids``.
    """

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

Liangsheng Yin's avatar
Liangsheng Yin committed
651

652
653
654
655
@dataclass
class BatchMultimodalDecodeReq:
    """A batch of multimodal decode results: finish reasons and token counts per request."""

    # The request id
    rids: List[str]
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
662
663


Lianmin Zheng's avatar
Lianmin Zheng committed
664
665
@dataclass
class BatchStrOut:
    """Batched detokenized (string) output.

    Every field is a list; presumably position ``i`` of each list refers to
    the request ``rids[i]`` (TODO confirm against the producer).
    """

    # ---- identity and termination -------------------------------------
    rids: List[str]
    # Finish reasons, represented as plain dicts.
    finished_reasons: List[dict]

    # ---- generated content --------------------------------------------
    # Decoded output strings.
    output_strs: List[str]
    # Output token ids.
    # NOTE(review): annotated as a flat List[int] even though this is a batch
    # message; it may really hold one id list per request — confirm.
    output_ids: Optional[List[int]]

    # ---- token accounting ---------------------------------------------
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Presumably a count of speculative-decoding verify steps — confirm.
    spec_verify_ct: List[int]

    # ---- logprobs -----------------------------------------------------
    # Stored as ``*_val`` / ``*_idx`` pairs; presumably the values and the
    # corresponding token indices (TODO confirm).
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # ---- hidden states ------------------------------------------------
    output_hidden_states: List[List[float]]

Liangsheng Yin's avatar
Liangsheng Yin committed
698

699
700
701
702
@dataclass
class BatchMultimodalOut:
    """Batched multimodal output message.

    Every field is a list; presumably entry ``i`` belongs to ``rids[i]``
    (TODO confirm against the producer).
    """

    # Request ids in this batch.
    rids: List[str]
    # Finish reasons, represented as plain dicts.
    finished_reasons: List[dict]
    # Output payloads: a list of dicts for each request.
    outputs: List[List[Dict]]

    # Per-request token accounting.
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
712
713


714
715
@dataclass
class BatchEmbeddingOut:
    """Batched embedding output message.

    Every field is a list; presumably entry ``i`` belongs to ``rids[i]``
    (TODO confirm against the producer).
    """

    # Request ids in this batch.
    rids: List[str]
    # Finish reason of each request.
    finished_reasons: List[BaseFinishReason]
    # One embedding vector per request.
    embeddings: List[List[float]]

    # Per-request token accounting (no completion tokens for embeddings).
    prompt_tokens: List[int]
    cached_tokens: List[int]
725
726


Liangsheng Yin's avatar
Liangsheng Yin committed
727
@dataclass
class FlushCacheReqInput:
    # Field-less marker message requesting a cache flush.
    pass
Cody Yu's avatar
Cody Yu committed
730

731

732
733
734
735
736
@dataclass
class FlushCacheReqOutput:
    """Output message for a cache-flush request."""

    # Whether the flush succeeded.
    success: bool


737
@dataclass
class UpdateWeightFromDiskReqInput:
    """Request to reload model weights from a path on disk."""

    # Location of the checkpoint that holds the new weights.
    model_path: str
    # Weight loading format; None leaves the choice to the receiver.
    load_format: Optional[str] = None


@dataclass
class UpdateWeightFromDiskReqOutput:
    """Result of a disk-based weight update."""

    # Outcome flag and a human-readable explanation.
    success: bool
    message: str
    # How many requests were paused while the weights were synced.
    num_paused_requests: Optional[int] = 0
751
752


753
754
755
756
757
758
759
760
761
762
763
764
765
@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request to update one named weight via a distributed transfer."""

    # Parameter name to update.
    name: str
    # Dtype of the incoming tensor, as a string.
    dtype: str
    # Shape of the incoming tensor.
    shape: List[int]


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Result of a distributed weight update."""

    # Outcome flag and a human-readable explanation.
    success: bool
    message: str


766
767
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Request to update model weights from in-memory tensors.

    - Tensors are serialized before they are sent.
    - The payload is JSON-structured so it can travel over HTTP.
    """

    # Serialized named tensors, as str or bytes.
    serialized_named_tensors: List[Union[str, bytes]]
    # Weight loading format; None means no explicit format.
    load_format: Optional[str] = None
    # Flush the cache once the weights are updated (on by default).
    flush_cache: bool = True
779
780
781
782
783
784
785
786


@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result of a tensor-based weight update."""

    # Outcome flag and a human-readable explanation.
    success: bool
    message: str


787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request to initialize the process group used for weight updates."""

    # The master address
    # NOTE(review): presumably the rendezvous host for the process group — confirm.
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    # NOTE(review): presumably added to local ranks to form global ranks — confirm.
    rank_offset: int
    # The world size
    world_size: int
    # The group name (defaults to "weight_update_group")
    group_name: str = "weight_update_group"
    # The communication backend (defaults to "nccl")
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Result of initializing the weight-update group."""

    # Outcome flag and a human-readable explanation.
    success: bool
    message: str


809
810
811
812
813
814
815
816
817
818
819
@dataclass
class GetWeightsByNameReqInput:
    """Request to fetch a model parameter by name."""

    # Parameter name to look up.
    name: str
    # NOTE(review): presumably caps how many elements are returned — confirm.
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    """Reply carrying the requested parameter data."""

    # Parameter contents, as a Python list.
    parameter: list


820
821
@dataclass
class ReleaseMemoryOccupationReqInput:
    """Request to release memory regions (mainly used for RL workflows)."""

    # Tags selecting which memory regions to act on. At present only
    # `weights` and `kv_cache` are supported; None applies no tag filter.
    tags: Optional[List[str]] = None
825
826
827
828
829
830
831
832
833


@dataclass
class ReleaseMemoryOccupationReqOutput:
    """Field-less acknowledgement for a release-memory request."""

    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    """Request to resume previously released memory (mainly used for RL)."""

    # Tags selecting which memory regions to act on. At present only
    # `weights` and `kv_cache` are supported; None applies no tag filter.
    tags: Optional[List[str]] = None
837
838
839
840
841
842
843


@dataclass
class ResumeMemoryOccupationReqOutput:
    """Field-less acknowledgement for a resume-memory request."""

    pass


844
845
846
847
848
849
850
851
852
853
@dataclass
class SlowDownReqInput:
    """Request to insert a sleep into the forward pass."""

    # NOTE(review): presumably a duration in seconds, with None meaning
    # "no slowdown" — confirm with the consumer.
    forward_sleep_time: Optional[float]


@dataclass
class SlowDownReqOutput:
    """Field-less acknowledgement for a slow-down request."""

    pass


854
855
@dataclass
class AbortReq:
    """Request to abort a single in-flight request."""

    # Id of the request to abort.
    rid: str
858
859


860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
@dataclass
class GetInternalStateReq:
    """Field-less request for the internal state."""

    pass


@dataclass
class GetInternalStateReqOutput:
    """Reply carrying the internal state."""

    # Free-form state mapping; the key/value schema is not fixed here.
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    """Request to update internal state from server-argument values."""

    # Mapping of server-argument names to new values.
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    """Result of a set-internal-state request."""

    # Whether the update was applied.
    updated: bool
    # Server-argument mapping reported back to the caller.
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    """User-facing options for starting a profiling session."""

    # Directory where profiling output is written.
    output_dir: Optional[str] = None
    # When set, profile this many steps and then stop automatically, so the
    # caller does not need to issue a separate stop_profile.
    num_steps: Optional[int] = None
    # Which activities to profile; None keeps the receiver's default.
    activities: Optional[List[str]] = None
    # Whether to profile stage by stage.
    profile_by_stage: bool = False
    # Optional profiler switches; None leaves the choice to the receiver.
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
893
894
895


class ProfileReqType(Enum):
    """Kind of profiling command."""

    # Begin a profiling session.
    START_PROFILE = 1
    # End a profiling session.
    STOP_PROFILE = 2
898
899


900
901
902
903
904
905
class ExpertDistributionReq(Enum):
    """Commands for the expert-distribution recorder."""

    # Begin recording.
    START_RECORD = 1
    # Stop recording.
    STOP_RECORD = 2
    # Dump what has been recorded.
    DUMP_RECORD = 3


906
907
908
909
910
@dataclass
class ExpertDistributionReqOutput:
    """Field-less acknowledgement for an expert-distribution command."""

    pass


911
912
913
914
915
916
@dataclass
class ProfileReq:
    """Internal profiling command: the command type plus its options."""

    # Whether this starts or stops profiling.
    type: ProfileReqType
    # Where profiling output goes.
    output_dir: Optional[str] = None
    # Optional step budget after which profiling stops.
    num_steps: Optional[int] = None
    # Activities to profile.
    activities: Optional[List[str]] = None
    # Whether to profile stage by stage.
    profile_by_stage: bool = False
    # Optional profiler switches.
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    # Optional identifier for this profiling session.
    profile_id: Optional[str] = None
921
922
923
924
925
926
927
928


@dataclass
class ProfileReqOutput:
    """Result of a profiling command."""

    # Outcome flag and a human-readable explanation.
    success: bool
    message: str


929
930
931
@dataclass
class ConfigureLoggingReq:
    """Request to change request-logging settings.

    Every field is optional; presumably None means "leave unchanged"
    (TODO confirm against the consumer).
    """

    # Request-logging switch and verbosity level.
    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    # Request-dumping destination folder and threshold.
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None


937
938
939
@dataclass
class OpenSessionReqInput:
    """Request to open a session."""

    # Session capacity, expressed as a string length.
    capacity_of_str_len: int
    # Caller-chosen session id; None leaves the choice to the receiver.
    session_id: Optional[str] = None
941
942
943
944
945
946
947
948
949


@dataclass
class CloseSessionReqInput:
    """Request to close a session."""

    # Id of the session to close.
    session_id: str


@dataclass
class OpenSessionReqOutput:
    """Result of opening a session."""

    # Id of the opened session (may be None).
    session_id: Optional[str]
    # Whether the session was opened.
    success: bool
YAMY's avatar
YAMY committed
952
953


954
955
956
957
958
@dataclass
class HealthCheckOutput:
    """Field-less health-check reply."""

    pass


YAMY's avatar
YAMY committed
959
960
961
962
963
964
965
966
967
968
969
970
971
972
@dataclass
class Function:
    """Description of a function offered to the model as a tool."""

    # Human-readable description.
    description: Optional[str] = None
    # Function name.
    name: Optional[str] = None
    # Parameter specification; the type is deliberately open (`object`).
    # NOTE(review): presumably a JSON-schema-like dict — confirm.
    parameters: Optional[object] = None


@dataclass
class Tool:
    """Tool entry wrapping a `Function` description."""

    function: Function
    # Tool kind; defaults to "function".
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    """Request to parse function/tool calls out of generated text."""

    # Text to parse.
    text: str
    # Available function tools (name, parameters, etc.); empty by default.
    tools: List[Tool] = field(default_factory=list)
    # Parser type, e.g. 'llama3', 'qwen25' or 'mistral'. When not given,
    # every parser is tried.
    tool_call_parser: Optional[str] = None
981
982


Xihuai Wang's avatar
Xihuai Wang committed
983
984
985
986
987
988
@dataclass
class SeparateReasoningReqInput:
    """Request to separate reasoning content from generated text."""

    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


989
990
991
992
@dataclass
class VertexGenerateReqInput:
    """Generate request in a Vertex-style envelope.

    NOTE(review): presumably mirrors the Vertex AI prediction request body
    (``instances`` + ``parameters``) — confirm.
    """

    # Per-instance request payloads.
    instances: List[dict]
    # Optional parameters shared by the request.
    parameters: Optional[dict] = None
993
994
995
996
997
998
999
1000
1001
1002
1003
1004


@dataclass
class RpcReqInput:
    """Generic RPC request: a method name plus optional parameters."""

    # Name of the method to invoke.
    method: str
    # Optional parameters for the call.
    # NOTE(review): presumably passed as keyword arguments — confirm.
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    """Result of a generic RPC request."""

    # Outcome flag and a human-readable explanation.
    success: bool
    message: str