io_struct.py 30.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""
The definition of objects transfered between different
16
processes (TokenizerManager, DetokenizerManager, Controller).
Lianmin Zheng's avatar
Lianmin Zheng committed
17
18
"""

19
import copy
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import uuid
YAMY's avatar
YAMY committed
21
from dataclasses import dataclass, field
22
from enum import Enum
23
24
25
26
27
28
29
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

# handle serialization of Image for pydantic
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any
30

31
from sglang.srt.managers.schedule_batch import BaseFinishReason
32
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
33
34


35
36
37
38
39
40
41
42
@dataclass
class SessionParams:
    """Parameters identifying a session for continual prompting."""

    # The session id.
    id: Optional[str] = None
    # The request id within the session to continue from.
    # NOTE(review): presumed from the name — confirm against session handling code.
    rid: Optional[str] = None
    # Offset into the referenced request — presumably a token offset; verify against caller.
    offset: Optional[int] = None
    # Whether to replace existing content when continuing — TODO confirm semantics.
    replace: Optional[bool] = None


Lianmin Zheng's avatar
Lianmin Zheng committed
43
44
@dataclass
class GenerateReqInput:
    """The input of a generate request, covering a single example or a batch.

    Most fields accept either a scalar (single request) or a list (batch).
    Call `normalize_batch_and_arguments()` after construction; it validates the
    inputs, expands parallel sampling (`n > 1`), and rewrites every per-request
    field into a consistent batch layout so that `self[i]` yields request *i*.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[
        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
    ] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # LoRA related
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()

        self._validate_session_params()

    def _validate_inputs(self):
        """Validate that the input configuration is valid."""
        # NOTE(review): this rejects "none provided" and "all three provided",
        # but exactly two of (text, input_ids, input_embeds) slips through —
        # confirm whether callers rely on that before tightening.
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size."""
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            # text takes precedence; drop any embeds that were also passed.
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            # Only input_embeds remains; a flat list of floats means one example.
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed."""
        # Determine parallel sample count from the sampling params' "n".
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )

        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]

    def _normalize_single_inputs(self):
        """Normalize inputs for a single example."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Calculate expanded batch size
        if self.parallel_sample_num == 1:
            num = self.batch_size
        else:
            # Expand parallel_sample_num
            num = self.batch_size * self.parallel_sample_num

        # Expand input based on type
        self._expand_inputs(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_rid(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling.

        Note: expansion repeats the whole batch (`[a, b] -> [a, b, a, b]`),
        not per element; downstream indexing must use the same layout.
        """
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Normalize LoRA paths for batch processing."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data for batch processing."""
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image, convert to list of single-image lists
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        elif isinstance(self.image_data, list):
            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is
                for i in range(len(self.image_data)):
                    if self.image_data[i] is None or self.image_data[i] == [None]:
                        self.modalities.append(None)
                    elif len(self.image_data[i]) == 1:
                        self.modalities.append("image")
                    elif len(self.image_data[i]) > 1:
                        self.modalities.append("multi-images")
                    else:
                        # Fix: an empty per-request image list previously appended
                        # nothing, leaving len(modalities) < batch_size and
                        # misaligning the per-request lookup in __getitem__.
                        self.modalities.append(None)
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        elif isinstance(self.audio_data, list):
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters for batch processing."""
        if self.sampling_params is None:
            # Fix: use independent dicts instead of `[{}] * num`, which aliases
            # one shared dict — a later in-place mutation of one request's
            # params would silently leak into every other request.
            self.sampling_params = [{} for _ in range(num)]
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif not isinstance(self.rid, list):
            raise ValueError("The rid should be a list for batch processing.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing."""

        # Helper function to normalize a parameter
        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        # Normalize each logprob parameter
        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing."""
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _validate_session_params(self):
        """Validate that session parameters are properly formatted."""
        if self.session_params is not None:
            if not isinstance(self.session_params, dict) and not isinstance(
                self.session_params[0], dict
            ):
                raise ValueError("Session params must be a dict or a list of dicts.")

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return request *i* of a normalized batch as a single-example input.

        NOTE(review): input_embeds and session_params are not propagated here —
        confirm whether that is intentional.
        """
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=self.return_hidden_states,
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
408
409
410

@dataclass
class TokenizedGenerateReqInput:
    """A generate request after tokenization, with per-request scalar fields."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token id to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_room: Optional[int] = None

Lianmin Zheng's avatar
Lianmin Zheng committed
452

453
454
455
456
@dataclass
class EmbeddingReqInput:
    """The input of an embedding request, covering a single example or a batch.

    Call `normalize_batch_and_arguments()` after construction; it validates
    the inputs, derives `batch_size`/`is_single`, and fills default `rid`
    and `sampling_params` values.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[
        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
    ] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Union[List[Dict], Dict] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None

    def normalize_batch_and_arguments(self):
        """Validate the inputs, derive the batch size, and fill in defaults.

        Raises:
            ValueError: If no input is provided, or if both text and
                input_ids are provided.
        """
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
            else:
                self.batch_size += 1

        if self.batch_size > 1:
            self.is_single = False

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            # Embedding requests never generate tokens.
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                # Fix: use independent dicts instead of `[{}] * batch_size`,
                # which aliases one shared dict across all requests — any
                # later per-request mutation would leak into the siblings.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return request *i* of a normalized batch as a single-example input."""
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
543
544
545


@dataclass
class TokenizedEmbeddingReqInput:
    """An embedding request after tokenization."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


Lianmin Zheng's avatar
Lianmin Zheng committed
559
560
@dataclass
class BatchTokenIDOut:
    """Token-level batch output; all lists are parallel, one entry per request."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # NOTE(review): presumably the count of speculative-decoding verify steps — confirm.
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

Liangsheng Yin's avatar
Liangsheng Yin committed
599

600
601
602
603
@dataclass
class BatchMultimodalDecodeReq:
    """Batch decode request for multimodal outputs; lists are parallel per request."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
610
611


Lianmin Zheng's avatar
Lianmin Zheng committed
612
613
@dataclass
class BatchStrOut:
    """Detokenized (string-level) batch output; lists are parallel per request."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # NOTE(review): presumably the count of speculative-decoding verify steps — confirm.
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

Liangsheng Yin's avatar
Liangsheng Yin committed
646

647
648
649
650
@dataclass
class BatchMultimodalOut:
    """Batch output for multimodal requests; lists are parallel per request."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
660
661


662
663
@dataclass
class BatchEmbeddingOut:
    """Batch output for embedding requests; lists are parallel per request."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]
673
674


Liangsheng Yin's avatar
Liangsheng Yin committed
675
@dataclass
class FlushCacheReqInput:
    """Request to flush the server cache; carries no payload."""

    pass
Cody Yu's avatar
Cody Yu committed
678

679

680
681
682
683
684
@dataclass
class FlushCacheReqOutput:
    """Result of a flush-cache request."""

    # Whether the flush completed successfully
    success: bool


685
@dataclass
class UpdateWeightFromDiskReqInput:
    """Request to reload model weights from a path on disk."""

    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None


@dataclass
class UpdateWeightFromDiskReqOutput:
    """Result of an update-weights-from-disk request."""

    # Whether the weight reload succeeded
    success: bool
    # Human-readable status message
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0
699
700


701
702
703
704
705
706
707
708
709
710
711
712
713
@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request to receive one named weight tensor over a distributed group."""

    # Name of the parameter to update
    name: str
    # Data type of the incoming tensor, as a string
    dtype: str
    # Shape of the incoming tensor
    shape: List[int]


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Result of an update-weights-from-distributed request."""

    # Whether the update succeeded
    success: bool
    # Human-readable status message
    message: str


714
715
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Update model weights from tensor input.

    - Tensors are serialized for transmission
    - Data is structured in JSON for easy transmission over HTTP
    """

    # Serialized named tensors, one entry per tensor-parallel rank
    serialized_named_tensors: List[Union[str, bytes]]
    # Optional format specification for loading
    load_format: Optional[str] = None
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
727
728
729
730
731
732
733
734


@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result of an update-weights-from-tensor request."""

    # Whether the update succeeded
    success: bool
    # Human-readable status message
    message: str


735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request to initialize a process group used for weight updates."""

    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Result of an init-weights-update-group request."""

    # Whether group initialization succeeded
    success: bool
    # Human-readable status message
    message: str


757
758
759
760
761
762
763
764
765
766
767
@dataclass
class GetWeightsByNameReqInput:
    """Request to fetch a named weight (truncated for transport)."""

    # Name of the parameter to fetch
    name: str
    # Maximum number of elements to return
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    """Result of a get-weights-by-name request."""

    # The (truncated) parameter values
    parameter: list


768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
@dataclass
class ReleaseMemoryOccupationReqInput:
    """Marker request to release memory occupation; carries no payload."""

    pass


@dataclass
class ReleaseMemoryOccupationReqOutput:
    """Acknowledgement for a release-memory-occupation request; no payload."""

    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    """Marker request to resume memory occupation; carries no payload."""

    pass


@dataclass
class ResumeMemoryOccupationReqOutput:
    """Acknowledgement for a resume-memory-occupation request; no payload."""

    pass


788
789
@dataclass
class AbortReq:
    """Request to abort an in-flight request by id."""

    # The request id
    rid: str
792
793


794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
@dataclass
class GetInternalStateReq:
    """Marker request to read internal server state; carries no payload."""

    pass


@dataclass
class GetInternalStateReqOutput:
    """Result of a get-internal-state request."""

    # Snapshot of internal state as a free-form dict
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    """Request to update internal server arguments."""

    # Server argument name -> new value
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    """Result of a set-internal-state request."""

    # Whether the update was applied
    updated: bool
    # The server arguments after the update
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    """HTTP-facing request to start profiling."""

    # The output directory
    output_dir: Optional[str] = None
    # If set, it profile as many as this number of steps.
    # If it is set, profiling is automatically stopped after this step, and
    # the caller doesn't need to run stop_profile.
    num_steps: Optional[int] = None
    # Which activities to record; None means the profiler's default set
    activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
824
825
826


class ProfileReqType(Enum):
    """Kind of profiling action being requested."""

    START_PROFILE = 1
    STOP_PROFILE = 2
829
830


831
832
833
834
835
836
class ExpertDistributionReq(Enum):
    """Control action for expert-distribution recording."""

    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


837
838
839
840
841
@dataclass
class ExpertDistributionReqOutput:
    """Acknowledgement for an expert-distribution request; no payload."""

    pass


842
843
844
845
846
847
@dataclass
class ProfileReq:
    """Internal profiling control message."""

    # Whether to start or stop profiling
    type: ProfileReqType
    # Directory to write profiler traces to
    output_dir: Optional[str] = None
    # If set, stop automatically after this many steps
    num_steps: Optional[int] = None
    # Which activities to record (e.g. "CPU", "GPU")
    activities: Optional[List[str]] = None
    # Whether to record Python call stacks
    with_stack: Optional[bool] = None
    # Whether to record tensor shapes
    record_shapes: Optional[bool] = None
    # Identifier correlating start/stop of one profiling session
    profile_id: Optional[str] = None
851
852
853
854
855
856
857
858


@dataclass
class ProfileReqOutput:
    """Result of a profiling control request."""

    # Whether the action succeeded
    success: bool
    # Human-readable status message
    message: str


859
860
861
@dataclass
class ConfigureLoggingReq:
    """Request to reconfigure request logging at runtime.

    All fields default to None, meaning "leave the current setting unchanged".
    """

    # Whether to log incoming requests
    log_requests: Optional[bool] = None
    # Verbosity level for request logging
    log_requests_level: Optional[int] = None
    # Folder to dump request records into
    dump_requests_folder: Optional[str] = None
    # Dump once this many requests have accumulated
    dump_requests_threshold: Optional[int] = None


867
868
869
@dataclass
class OpenSessionReqInput:
    """Request to open a new session."""

    # Capacity reserved for string lengths in this session
    capacity_of_str_len: int
    # Desired session id; None lets the server assign one
    session_id: Optional[str] = None
871
872
873
874
875
876
877
878
879


@dataclass
class CloseSessionReqInput:
    """Request to close an existing session."""

    # Id of the session to close
    session_id: str


@dataclass
class OpenSessionReqOutput:
    """Result of an open-session request."""

    # Assigned session id; may be None when opening failed
    session_id: Optional[str]
    # Whether the session was opened successfully
    success: bool
YAMY's avatar
YAMY committed
882
883


884
885
886
887
888
@dataclass
class HealthCheckOutput:
    """Response message for a health check; carries no payload."""

    pass


YAMY's avatar
YAMY committed
889
890
891
892
893
894
895
896
897
898
899
900
901
902
@dataclass
class Function:
    """Function description used by tool/function calling."""

    # Human-readable description of the function
    description: Optional[str] = None
    # Function name
    name: Optional[str] = None
    # Parameter schema; shape is opaque here (presumably JSON-schema-like — verify against callers)
    parameters: Optional[object] = None


@dataclass
class Tool:
    """A tool exposed to the model; wraps a Function."""

    # The wrapped function description
    function: Function
    # Tool type tag; defaults to "function"
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    """Request to parse tool/function calls out of generated text."""

    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
911
912


Xihuai Wang's avatar
Xihuai Wang committed
913
914
915
916
917
918
@dataclass
class SeparateReasoningReqInput:
    """Request to split reasoning content from the final answer in text."""

    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


919
920
921
922
@dataclass
class VertexGenerateReqInput:
    """Generate request in Google Vertex AI's instances/parameters format."""

    # One dict per instance to generate for
    instances: List[dict]
    # Shared generation parameters applied to all instances
    parameters: Optional[dict] = None
923
924
925
926
927
928
929
930
931
932
933
934


@dataclass
class RpcReqInput:
    """Generic RPC request: a method name plus optional parameters."""

    # Name of the method to invoke
    method: str
    # Keyword parameters for the method
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    """Generic RPC response."""

    # Whether the RPC succeeded
    success: bool
    # Human-readable status message
    message: str