io_struct.py 45.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""
15
The definition of objects transferred between different
16
processes (TokenizerManager, DetokenizerManager, Scheduler).
Lianmin Zheng's avatar
Lianmin Zheng committed
17
18
"""

19
import copy
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import uuid
YAMY's avatar
YAMY committed
21
from dataclasses import dataclass, field
22
from enum import Enum
23
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
24

25
from sglang.srt.lora.lora_registry import LoRARef
26
from sglang.srt.managers.schedule_batch import BaseFinishReason
27
from sglang.srt.multimodal.mm_utils import has_valid_data
28
from sglang.srt.sampling.sampling_params import SamplingParams
29
from sglang.srt.utils import ImageData
30

31
# Handle serialization of Image for pydantic
32
33
34
35
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any
36

Lianmin Zheng's avatar
Lianmin Zheng committed
37

38
39
40
41
42
43
@dataclass
class SessionParams:
    """Session options attached to a request for continual prompting.

    Every attribute is optional and defaults to ``None``; fields can be given
    positionally in declaration order or by keyword.
    """

    id: Optional[str] = field(default=None)
    rid: Optional[str] = field(default=None)
    offset: Optional[int] = field(default=None)
    replace: Optional[bool] = field(default=None)
    drop_previous_output: Optional[bool] = field(default=None)
45
46


47
48
# Type definitions for multimodal input data.
#
# Per-modality item types: one piece of input for a single modality. Audio and
# video items are a string (file name, URL, or base64) or a dict; image items
# additionally accept an ``Image`` object or an ``ImageData`` wrapper.
ImageDataInputItem = Union[Image, str, ImageData, Dict]
AudioDataInputItem = Union[str, Dict]
VideoDataInputItem = Union[str, Dict]

# A single multimodal item of any modality.
MultimodalDataInputItem = Union[
    ImageDataInputItem,
    VideoDataInputItem,
    AudioDataInputItem,
]

# Accepted container shapes, from most to least nested:
#   - a list of per-request lists of items (batch, several items per request),
#   - a flat list of items,
#   - a single item.
MultimodalDataInputFormat = Union[
    List[List[MultimodalDataInputItem]],
    List[MultimodalDataInputItem],
    MultimodalDataInputItem,
]
62
63


Lianmin Zheng's avatar
Lianmin Zheng committed
64
65
@dataclass
class GenerateReqInput:
    """A generation request as received from the API layer.

    One instance describes either a single request or a batch of requests.
    Call ``normalize_batch_and_arguments`` before use: it sets ``is_single``,
    ``batch_size`` and ``parallel_sample_num`` and fills in defaults. After a
    batch has been normalized, per-request fields are lists and ``obj[i]``
    returns the i-th sub-request as a standalone ``GenerateReqInput``.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # Sampling parameters: one dict shared by all requests, or one dict per request.
    # The "n" key selects the number of parallel samples per request.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # Whether to return hidden states (one flag for all requests, or one per request)
    return_hidden_states: Union[List[bool], bool] = False

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # Session info for continual prompting (one dict, or one dict per request)
    session_params: Optional[Union[List[Dict], Dict]] = None

    # The path to the LoRA adaptors
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # The uid of LoRA adaptors, should be initialized by tokenizer manager
    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None
    bootstrap_pair_key: Optional[Union[List[str], str]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For background responses (OpenAI responses API)
    background: bool = False

    # Conversation id used for tracking requests
    conversation_id: Optional[str] = None

    # Label for the request
    label: Optional[str] = None

    # Priority for the request
    priority: Optional[int] = None

    # Image gen grpc migration
    return_bytes: bool = False

    def contains_mm_input(self) -> bool:
        """Return whether ``has_valid_data`` accepts the image, video, or audio input."""
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()

    def _validate_inputs(self):
        """Reject requests that give none, or all three, of the input sources.

        Raises:
            ValueError: If text, input_ids and input_embeds are all None or all set.
        """
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Set ``is_single`` and ``batch_size`` from the first non-None input source.

        Precedence is text, then input_ids, then input_embeds; when text or
        input_ids is used, ``input_embeds`` is cleared.

        Raises:
            ValueError: If the chosen ``input_ids`` or ``input_embeds`` is empty.
        """
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            # Guard the [0][0] probe below; an empty list used to surface as IndexError.
            if len(self.input_embeds) == 0 or len(self.input_embeds[0]) == 0:
                raise ValueError("input_embeds cannot be empty.")
            # A single request is a list of per-token vectors, so [0][0] is a
            # number; a batch nests one level deeper. Accept ints as well as
            # floats (e.g. JSON-encoded 0 instead of 0.0).
            if isinstance(self.input_embeds[0][0], (int, float)):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Set ``parallel_sample_num`` and turn a single request into a batch when n > 1.

        Raises:
            ValueError: If a list of sampling params disagrees on "n".
        """
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list)
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            if any(
                params.get("n", 1) != self.parallel_sample_num
                for params in self.sampling_params
            ):
                raise ValueError(
                    "The parallel_sample_num should be the same for all samples in sample params."
                )

        # If using parallel sampling with a single example, convert to a batch of one
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
            if self.input_embeds is not None:
                self.input_embeds = [self.input_embeds]
            # A single request may carry a flat list of several multimodal items.
            # Wrap it as well, so the batch normalizers see one per-request list
            # instead of treating each item as a separate request.
            self.image_data = self._wrap_single_request_mm_data(self.image_data)
            self.video_data = self._wrap_single_request_mm_data(self.video_data)
            self.audio_data = self._wrap_single_request_mm_data(self.audio_data)

    @staticmethod
    def _wrap_single_request_mm_data(data):
        """Wrap one request's flat, non-empty item list into a batch of one.

        Scalars, None, empty lists and already-nested lists are returned unchanged.
        """
        if isinstance(data, list) and data and not isinstance(data[0], list):
            return [data]
        return data

    def _normalize_single_inputs(self):
        """Fill defaults for a single (non-batch) request."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion.

        Every per-request field is turned into a list of
        ``batch_size * parallel_sample_num`` entries (rid excepted, see
        ``_normalize_rid``).
        """
        num = self.batch_size * self.parallel_sample_num

        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_video_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)
        self._normalize_bootstrap_params(num)

    def _expand_per_request(self, value, num):
        """Return ``value`` as a list with one entry per expanded request.

        ``None`` and scalars are repeated ``num`` times. A list (one entry per
        original request) is tiled ``parallel_sample_num`` times, matching how
        the main inputs are expanded in ``_expand_inputs``.
        """
        if value is None:
            return [None] * num
        if isinstance(value, list):
            return value * self.parallel_sample_num
        return [value] * num

    def _expand_inputs(self, num):
        """Tile the main input (text, input_ids or input_embeds) for parallel sampling.

        Raises:
            ValueError: If the input is not in list (batch) form.
        """
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Broadcast a single LoRA path or tile a per-request list; None stays None.

        Raises:
            ValueError: If lora_path is neither a string nor a list.
        """
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    @staticmethod
    def _image_modality(images):
        """Classify one request's image list as None, "image" or "multi-images"."""
        if images is None or images == [None]:
            return None
        if len(images) == 1:
            return "image"
        if len(images) > 1:
            return "multi-images"
        # Empty list: keep len(modalities) == len(image_data)
        return None

    def _normalize_image_data(self, num):
        """Normalize image data to one entry per expanded request and set ``modalities``.

        Raises:
            ValueError: If a non-empty image list's length differs from ``batch_size``.
        """
        if self.image_data is None:
            self.image_data = [None] * num
            return

        if not isinstance(self.image_data, list):
            # One image shared by every request; give each entry its own inner
            # list so the entries do not alias one another.
            self.image_data = [[self.image_data] for _ in range(num)]
            self.modalities = ["image"] * num
            return

        # Handle empty list case - treat as no images
        if len(self.image_data) == 0:
            self.image_data = [None] * num
            return

        if len(self.image_data) != self.batch_size:
            raise ValueError(
                "The length of image_data should be equal to the batch size."
            )

        if isinstance(self.image_data[0], list):
            # Already one list of images per request; derive each request's modality.
            modalities = [self._image_modality(images) for images in self.image_data]
            self.image_data = self.image_data * self.parallel_sample_num
            self.modalities = modalities * self.parallel_sample_num
        else:
            # One image per request: wrap each in its own list, then tile.
            wrapped_images = [[img] for img in self.image_data]
            self.image_data = wrapped_images * self.parallel_sample_num
            self.modalities = ["image"] * num

    def _normalize_video_data(self, num):
        """Normalize video data to one entry per expanded request."""
        self.video_data = self._expand_per_request(self.video_data, num)

    def _normalize_audio_data(self, num):
        """Normalize audio data to one entry per expanded request."""
        self.audio_data = self._expand_per_request(self.audio_data, num)

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters to one dict per expanded request."""
        if self.sampling_params is None:
            # A distinct dict per request, so no two requests share the same object.
            self.sampling_params = [{} for _ in range(num)]
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing.

        Raises:
            ValueError: If a rid list's length differs from ``batch_size``, or rid
                has an unsupported type.
        """
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            self.rid = [f"{self.rid}_{i}" for i in range(num)]
        elif isinstance(self.rid, list):
            # Note: the length of rid shall be the same as the batch_size,
            # as the rid would be expanded for parallel sampling in tokenizer_manager
            if len(self.rid) != self.batch_size:
                raise ValueError(
                    "The specified rids length mismatch with the batch_size for batch processing."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing.

        Raises:
            ValueError: If a per-request list is combined with parallel sampling.
        """

        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            if not isinstance(param, list):
                return [param] * num
            if self.parallel_sample_num > 1:
                raise ValueError(
                    f"Cannot use list {param_name} with parallel_sample_num > 1"
                )
            return param

        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            # One shared id list: give every request its own copy.
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing.

        Raises:
            ValueError: If a per-request list is combined with parallel sampling.
        """
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _normalize_bootstrap_params(self, num):
        """Normalize disaggregation bootstrap parameters for batch processing."""
        self.bootstrap_host = self._expand_per_request(self.bootstrap_host, num)
        self.bootstrap_port = self._expand_per_request(self.bootstrap_port, num)

        if self.bootstrap_room is None:
            self.bootstrap_room = [None] * num
        elif not isinstance(self.bootstrap_room, list):
            # A scalar room is a base value: request i gets bootstrap_room + i.
            self.bootstrap_room = [self.bootstrap_room + i for i in range(num)]
        else:
            self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num

        self.bootstrap_pair_key = self._expand_per_request(
            self.bootstrap_pair_key, num
        )

    def _validate_session_params(self):
        """Check that session_params is None, a dict, or a list of dicts.

        Raises:
            ValueError: If session_params has any other shape.
        """
        if self.session_params is None or isinstance(self.session_params, dict):
            return
        # Check every element (not just the first) and tolerate an empty list,
        # which previously raised IndexError.
        if isinstance(self.session_params, list) and all(
            isinstance(params, dict) for params in self.session_params
        ):
            return
        raise ValueError("Session params must be a dict or a list of dicts.")

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return the i-th sub-request of a normalized batch as a standalone request."""
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            input_embeds=(
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
            video_data=self.video_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            modalities=self.modalities[i] if self.modalities else None,
            # A per-request list must be indexed; forwarding the whole list would
            # hand every sub-request all sessions instead of its own dict.
            session_params=(
                self.session_params[i]
                if isinstance(self.session_params, list)
                else self.session_params
            ),
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            lora_id=self.lora_id[i] if self.lora_id is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_port=(
                self.bootstrap_port[i] if self.bootstrap_port is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            bootstrap_pair_key=(
                self.bootstrap_pair_key[i]
                if self.bootstrap_pair_key is not None
                else None
            ),
            data_parallel_rank=self.data_parallel_rank,
            conversation_id=self.conversation_id,
            label=self.label,
            priority=self.priority,
            return_bytes=self.return_bytes,
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
547
548
549

@dataclass
class TokenizedGenerateReqInput:
    """A single generation request whose prompt has already been tokenized.

    The fields without defaults must be supplied (positionally in this order,
    or by keyword); the remaining fields are optional.
    """

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token ids to return logprob for. None when not
    # requested (a normalized GenerateReqInput yields None for missing/empty lists).
    token_ids_logprob: Optional[List[int]]
    # Whether to stream output
    stream: bool
    # Whether to return hidden states
    return_hidden_states: bool = False

    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # LoRA related
    lora_id: Optional[str] = None  # None means just use the base model

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None
    bootstrap_pair_key: Optional[str] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For dp balance
    dp_balance_id: int = -1

    # Label for the request
    label: Optional[str] = None

    # Priority for the request
    priority: Optional[int] = None

    # Image gen grpc migration
    return_bytes: bool = False

    # tracing context
    trace_context: Optional[Dict] = None

Lianmin Zheng's avatar
Lianmin Zheng committed
611

612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
@dataclass
class BatchTokenizedGenerateReqInput:
    """A group of tokenized generation requests carried as one object.

    Acts as a read-only sequence over ``batch``: ``len()``, indexing and
    iteration are all delegated to the underlying list.
    """

    # The tokenized requests contained in this batch.
    batch: List[TokenizedGenerateReqInput]

    def __iter__(self):
        """Iterate over the contained requests."""
        return iter(self.batch)

    def __len__(self):
        """Return how many requests the batch holds."""
        return self.batch.__len__()

    def __getitem__(self, i):
        """Delegate indexing (an int or a slice) to the underlying list."""
        return self.batch.__getitem__(i)


627
628
629
@dataclass
class EmbeddingReqInput:
    """An embedding request for a single prompt or a batch of prompts.

    ``batch_size`` and ``is_single`` are not dataclass fields; they are set by
    ``normalize_batch_and_arguments``.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False
    # For background responses (OpenAI responses API)
    background: bool = False
    # tracing context
    trace_context: Optional[Dict] = None

    def normalize_batch_and_arguments(self):
        """Validate the inputs, derive the batch size and fill in defaults.

        Sets ``self.batch_size`` and ``self.is_single``, and generates request
        ids when ``rid`` is None. It also forces ``max_new_tokens`` to 0 in
        the sampling params of every request.

        Raises:
            ValueError: if none of text / input_ids / image_data is given, or
                if both text and input_ids are given.
        """
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size. text and input_ids are mutually exclusive
        # (checked above), so at most one of the branches below contributes.
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            # Build one independent dict per request. `[d] * n` would make every
            # entry alias the same dict, and for a caller-supplied dict it
            # would also mutate the caller's object in the loop below.
            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif isinstance(self.sampling_params, dict):
                self.sampling_params = [
                    dict(self.sampling_params) for _ in range(self.batch_size)
                ]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        """Replace ``rid`` with a fresh single uuid hex string and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def contains_mm_input(self) -> bool:
        """Return True if any image, video or audio data is present."""
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def __getitem__(self, i):
        """Return the ``i``-th sub-request of a batched request.

        Indexes ``rid`` and ``sampling_params`` as lists, so the request must
        be a batch that has already gone through
        ``normalize_batch_and_arguments``.
        """
        if self.is_cross_encoder_request:
            # Cross-encoder sub-requests keep their text wrapped in a list.
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                is_cross_encoder_request=True,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            audio_data=self.audio_data[i] if self.audio_data is not None else None,
            video_data=self.video_data[i] if self.video_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
743
744
745


@dataclass
class TokenizedEmbeddingReqInput:
    """A single embedding request after tokenization."""

    # Unique id of this request
    rid: str
    # Prompt text as received
    input_text: str
    # Token ids of the prompt
    input_ids: List[int]
    # Image inputs attached to the request
    image_inputs: dict
    # Token type ids of the prompt
    token_type_ids: List[int]
    # Placeholder sampling params, present only for interface compatibility
    sampling_params: SamplingParams
    # Optional data-parallel rank to route this request to
    data_parallel_rank: Optional[int] = None
    # dp-balance identifier; -1 unless set
    dp_balance_id: int = -1
763
764


765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
@dataclass
class BatchTokenizedEmbeddingReqInput:
    """A list of tokenized embedding requests bundled into one object.

    ``len()``, indexing and iteration are all forwarded to ``batch``.
    """

    # The tokenized embedding requests in this batch
    batch: List[TokenizedEmbeddingReqInput]

    def __iter__(self):
        return iter(self.batch)

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]


Lianmin Zheng's avatar
Lianmin Zheng committed
780
781
@dataclass
class BatchTokenIDOut:
    """Per-batch token-id outputs; every list field holds one entry per request."""

    # Request ids
    rids: List[str]
    # Finish reason of each request
    finished_reasons: List[BaseFinishReason]

    # State for incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]

    # Detokenization options
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token accounting
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Log probabilities (values and token indices)
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

    # Placeholder tokens (e.g. image tokens): idx is the token's index in the
    # prompt after expansion, val is the number of padded tokens after expansion.
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

Liangsheng Yin's avatar
Liangsheng Yin committed
826

827
828
@dataclass
class BatchMultimodalDecodeReq:
    """Per-batch decode request for multimodal outputs; list fields hold one entry per request."""

    decoded_ids: List[int]
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    read_offsets: List[int]
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    image_resolutions: List[List[int]]
    resize_image_resolutions: List[List[int]]

    # Request ids
    rids: List[str]
    finished_reasons: List[BaseFinishReason]

    # Token accounting
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    # Placeholder token information
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    return_bytes: bool = False

855

Lianmin Zheng's avatar
Lianmin Zheng committed
856
857
@dataclass
class BatchStrOut:
    """Per-batch detokenized string outputs; list fields hold one entry per request."""

    # Request ids
    rids: List[str]
    # Finish reason of each request
    finished_reasons: List[dict]
    # Decoded output strings
    output_strs: List[str]
    # Output token ids
    output_ids: Optional[List[int]]

    # Token accounting
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Log probabilities (values and token indices)
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

    # Placeholder token information
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

Liangsheng Yin's avatar
Liangsheng Yin committed
893

894
895
896
897
@dataclass
class BatchMultimodalOut:
    """Per-batch multimodal outputs; list fields hold one entry per request."""

    # Request ids
    rids: List[str]
    # Finish reason of each request
    finished_reasons: List[dict]
    decoded_ids: List[List[int]]
    # The outputs: one str/bytes per request, or one list of dicts per request.
    # Spelled `Union[str, bytes]` rather than `str | bytes`: annotations here are
    # evaluated at class creation, and the `|` form raises TypeError before 3.10.
    outputs: Union[List[Union[str, bytes]], List[List[Dict]]]

    # Log probability values and token indices for input and output tokens
    input_token_logprobs_val: List[List[float]]
    input_token_logprobs_idx: List[List[int]]
    output_token_logprobs_val: List[List[float]]
    output_token_logprobs_idx: List[List[int]]

    # Token accounting
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    # Placeholder token information
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    return_bytes: List[bool]

920

921
922
@dataclass
class BatchEmbeddingOut:
    """Per-batch embedding results; list fields hold one entry per request."""

    # Request ids
    rids: List[str]
    # Finish reason of each request
    finished_reasons: List[BaseFinishReason]
    # One embedding vector per request
    embeddings: List[List[float]]

    # Token accounting
    prompt_tokens: List[int]
    cached_tokens: List[int]

    # Placeholder token information
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]
935
936


937
938
939
940
941
942
943
944
945
946
@dataclass
class ClearHiCacheReqInput:
    """Request to clear the HiCache; it carries no fields."""


@dataclass
class ClearHiCacheReqOutput:
    """Result of a clear-HiCache request."""

    # True if the cache was cleared
    success: bool


Liangsheng Yin's avatar
Liangsheng Yin committed
947
@dataclass
class FlushCacheReqInput:
    """Request to flush the cache; it carries no fields."""
Cody Yu's avatar
Cody Yu committed
950

951

952
953
954
955
956
@dataclass
class FlushCacheReqOutput:
    """Result of a flush-cache request."""

    # True if the flush succeeded
    success: bool


957
@dataclass
class UpdateWeightFromDiskReqInput:
    """Request to reload model weights from a path on disk."""

    # Location of the new weights
    model_path: str
    # Weight loading format; None leaves the choice to the loader
    load_format: Optional[str] = None
    # Abort every in-flight request before the update
    abort_all_requests: bool = False
    # Optional new weight version to record alongside the weights
    weight_version: Optional[str] = None
    # Perform the update asynchronously
    is_async: bool = False
    # Empty the torch cache
    torch_empty_cache: bool = False
    # Leave the scheduler paused once the update finishes
    keep_pause: bool = False
973
974
975


@dataclass
class UpdateWeightFromDiskReqOutput:
    """Result of a disk weight update."""

    success: bool
    message: str
    # How many requests were paused while weights were synchronized
    num_paused_requests: Optional[int] = 0
981
982


983
984
@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request to update weights that arrive through a distributed group."""

    # Parallel lists describing each tensor: name, dtype and shape
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    # Name of the communication group to use
    group_name: str = "weight_update_group"
    # Flush the cache once the weights have been updated
    flush_cache: bool = True
    # Abort every in-flight request before the update
    abort_all_requests: bool = False
    # Optional new weight version to record alongside the weights
    weight_version: Optional[str] = None
996
997
998
999
1000
1001
1002
1003


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Result of a distributed weight update."""

    success: bool
    message: str


1004
1005
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Request to update model weights from tensor data.

    - The tensors travel in serialized form.
    - The payload is JSON-structured so it can be sent over HTTP.
    """

    serialized_named_tensors: List[Union[str, bytes]]
    # Optional loading format
    load_format: Optional[str] = None
    # Flush the cache once the weights have been updated
    flush_cache: bool = True
    # Abort every in-flight request before the update
    abort_all_requests: bool = False
    # Optional new weight version to record alongside the weights
    weight_version: Optional[str] = None
1021
1022
1023
1024
1025
1026
1027
1028


@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result of a tensor-based weight update."""

    success: bool
    message: str


1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
@dataclass
class InitWeightsSendGroupForRemoteInstanceReqInput:
    """Request to set up the communication group that sends weights to a remote instance."""

    # Address of the master
    master_address: str
    # Communication-group ports, one per rank (passed as a single string)
    ports: str
    # This process's rank within the group
    group_rank: int
    # Total number of ranks in the group
    world_size: int
    # Name of the group
    group_name: str = "weight_send_group"
    # Communication backend
    backend: str = "nccl"


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput:
    """Result of setting up the remote weight-send group."""

    success: bool
    message: str


@dataclass
class SendWeightsToRemoteInstanceReqInput:
    """Request to send weights to a remote instance."""

    # Address of the master
    master_address: str
    # Communication-group ports, one per rank (passed as a single string)
    ports: str
    # Name of the group
    group_name: str = "weight_send_group"


@dataclass
class SendWeightsToRemoteInstanceReqOutput:
    """Result of sending weights to a remote instance."""

    success: bool
    message: str


1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request to set up the communication group used for weight updates."""

    # Address of the master
    master_address: str
    # Port of the master
    master_port: int
    # Offset applied to ranks
    rank_offset: int
    # Total number of ranks in the group
    world_size: int
    # Name of the group
    group_name: str = "weight_update_group"
    # Communication backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Result of setting up the weight-update group."""

    success: bool
    message: str


1089
1090
1091
1092
1093
1094
1095
1096
@dataclass
class UpdateWeightVersionReqInput:
    """Request to change the recorded weight version."""

    # Version string to set
    new_version: str
    # Abort all running requests before switching (enabled by default)
    abort_all_requests: bool = True


1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
@dataclass
class GetWeightsByNameReqInput:
    """Request to fetch a model parameter by name."""

    # Parameter name
    name: str
    # Truncation size for the returned parameter (default 100)
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    """Parameter values returned by a get-weights-by-name request."""

    parameter: list


1108
1109
@dataclass
class ReleaseMemoryOccupationReqInput:
    """Request to release occupied memory, mainly used for RL workflows."""

    # Optional tags selecting which memory regions to release.
    # Supported values at present: `weights` and `kv_cache`.
    tags: Optional[List[str]] = None
1113
1114
1115
1116


@dataclass
class ReleaseMemoryOccupationReqOutput:
    """Acknowledgement of a release-memory request; it carries no fields."""
1118
1119
1120
1121


@dataclass
class ResumeMemoryOccupationReqInput:
    """Request to resume memory occupation, mainly used for RL workflows."""

    # Optional tags selecting which memory regions to resume.
    # Supported values at present: `weights` and `kv_cache`.
    tags: Optional[List[str]] = None
1125
1126
1127
1128


@dataclass
class ResumeMemoryOccupationReqOutput:
    """Acknowledgement of a resume-memory request; it carries no fields."""
1130
1131


1132
1133
1134
1135
1136
1137
1138
@dataclass
class SlowDownReqInput:
    """Request to slow down forward passes."""

    # Sleep time applied to forward passes; units and the meaning of None are
    # defined by the consumer — NOTE(review): presumably seconds, confirm
    forward_sleep_time: Optional[float]


@dataclass
class SlowDownReqOutput:
    """Acknowledgement of a slow-down request; it carries no fields."""
1140
1141


1142
1143
@dataclass
class AbortReq:
    """Request to abort one request (by ``rid``) or every request (``abort_all``)."""

    # The request id
    rid: str = ""
    # Whether to abort all requests
    abort_all: bool = False
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None
    abort_reason: Optional[str] = None
    # Used in MultiTokenizerManager mode; defaults to `rid` when not supplied
    rids: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        # Mirror `rid` into `rids` only when the caller did not pass `rids`
        # explicitly. Previously an explicit value was silently overwritten.
        if self.rids is None:
            self.rids = self.rid
1156
1157


1158
1159
@dataclass
class GetInternalStateReq:
    """Request for the internal state; it carries no fields."""
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185


@dataclass
class GetInternalStateReqOutput:
    """Internal state returned by a get-internal-state request."""

    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    """Request to update server arguments at runtime."""

    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    """Result of a set-internal-state request."""

    # Whether the update was applied
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    """User-facing request to start profiling."""

    # Directory to write profiler output to
    output_dir: Optional[str] = None
    # When num_steps is set, profile that many steps. Profiling then stops
    # automatically after the last step, so the caller need not call
    # stop_profile.
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
1192
1193
1194


class ProfileReqType(Enum):
    """Whether a profile request starts or stops profiling."""

    START_PROFILE = 1
    STOP_PROFILE = 2
1197
1198


1199
1200
1201
1202
@dataclass
class ProfileReq:
    """Internal profile command (start or stop) with its options."""

    # Start or stop
    type: ProfileReqType
    output_dir: Optional[str] = None
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    profile_id: Optional[str] = None
1210
1211
1212
1213
1214
1215
1216
1217


@dataclass
class ProfileReqOutput:
    """Result of a profile command."""

    success: bool
    message: str


1218
1219
1220
1221
1222
@dataclass
class FreezeGCReq:
    """Request to freeze the garbage collector; it carries no fields."""


1223
1224
1225
@dataclass
class ConfigureLoggingReq:
    """Request to change logging settings; a field left as None is not set by this request."""

    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None
    crash_dump_folder: Optional[str] = None
1230
1231


1232
1233
1234
@dataclass
class OpenSessionReqInput:
    """Request to open a session."""

    capacity_of_str_len: int
    # Requested session id; may be omitted
    session_id: Optional[str] = None
1236
1237
1238
1239
1240
1241
1242
1243
1244


@dataclass
class CloseSessionReqInput:
    """Request to close the session identified by ``session_id``."""

    session_id: str


@dataclass
class OpenSessionReqOutput:
    """Result of opening a session."""

    session_id: Optional[str]
    success: bool
YAMY's avatar
YAMY committed
1247
1248


1249
1250
1251
1252
1253
@dataclass
class HealthCheckOutput:
    """Health-check reply; it carries no fields."""


1254
1255
1256
1257
1258
1259
1260
1261
class ExpertDistributionReq(Enum):
    """Commands for recording the expert distribution: start, stop or dump."""

    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


@dataclass
class ExpertDistributionReqOutput:
    """Acknowledgement of an expert-distribution command; it carries no fields."""
1263
1264


YAMY's avatar
YAMY committed
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
@dataclass
class Function:
    """Description of a callable function tool; every field is optional."""

    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    """A tool entry wrapping a `Function`; ``type`` defaults to "function"."""

    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    """Request to parse function calls out of a piece of text."""

    # Text to parse
    text: str
    # Available function tools (name, parameters, etc.); a fresh list per instance
    tools: List[Tool] = field(default_factory=list)
    # Parser type, e.g. 'llama3', 'qwen25' or 'mistral'. If not specified, tries all.
    tool_call_parser: Optional[str] = None
1287
1288


Xihuai Wang's avatar
Xihuai Wang committed
1289
1290
1291
1292
1293
1294
@dataclass
class SeparateReasoningReqInput:
    """Request to separate reasoning content from a piece of text."""

    # Text to parse
    text: str
    # Parser type, e.g. "deepseek-r1"
    reasoning_parser: str


1295
1296
1297
1298
@dataclass
class VertexGenerateReqInput:
    """Generate request expressed as a list of instance dicts plus optional parameters."""

    instances: List[dict]
    parameters: Optional[dict] = None
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310


@dataclass
class RpcReqInput:
    """RPC request naming a method and its optional parameters."""

    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    """Result of an RPC request."""

    success: bool
    message: str
1311
1312
1313
1314
1315
1316
1317
1318


@dataclass
class LoadLoRAAdapterReqInput:
    """Request to load a new LoRA adapter."""

    # Name under which the adapter is loaded
    lora_name: str
    # Path to load the adapter from
    lora_path: str
    # Pin the adapter in memory
    pinned: bool = False
    # Unique adapter id; generated automatically in the `TokenizerManager`
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        """Build a `LoRARef` from this request's id, name, path and pin flag."""
        ref_kwargs = dict(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
            lora_path=self.lora_path,
            pinned=self.pinned,
        )
        return LoRARef(**ref_kwargs)
1331
1332
1333
1334
1335
1336


@dataclass
class UnloadLoRAAdapterReqInput:
    """Request to unload a LoRA adapter."""

    # Name of the adapter to unload
    lora_name: str
    # Unique adapter id; generated automatically in the `TokenizerManager`
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        """Build a `LoRARef` from this request's id and name."""
        ref_kwargs = dict(lora_id=self.lora_id, lora_name=self.lora_name)
        return LoRARef(**ref_kwargs)
1345
1346
1347
1348
1349
1350


@dataclass
class LoRAUpdateResult:
    """Outcome of a LoRA load or unload."""

    success: bool
    # Error description when the update failed
    error_message: Optional[str] = None
    # Loaded adapters, keyed by string
    loaded_adapters: Optional[Dict[str, LoRARef]] = None
1352
1353
1354


# Loading and unloading a LoRA adapter both reply with a `LoRAUpdateResult`.
LoadLoRAAdapterReqOutput = LoRAUpdateResult
UnloadLoRAAdapterReqOutput = LoRAUpdateResult
fzyzcjy's avatar
fzyzcjy committed
1355
1356


1357
1358
1359
1360
1361
1362
1363
@dataclass
class MultiTokenizerRegisterReq:
    """Registration message used in multi-tokenizer mode."""

    rids: Optional[Union[List[str], str]] = None
    ipc_name: Optional[str] = None


@dataclass
class MultiTokenizerWrapper:
    """Pairs an arbitrary object with a worker id."""

    worker_id: int
    obj: Optional[Any] = None


fzyzcjy's avatar
fzyzcjy committed
1369
1370
1371
1372
1373
1374
1375
1376
class BlockReqType(Enum):
    """Whether a block request blocks or unblocks."""

    BLOCK = 1
    UNBLOCK = 2


@dataclass
class BlockReqInput:
    type: BlockReqType