# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Scheduler).
"""

import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.utils import ImageData

# Handle serialization of Image for pydantic
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any


@dataclass
class SessionParams:
    id: Optional[str] = None
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
    drop_previous_output: Optional[bool] = None


# Type definitions for multimodal input data
# Individual data item types for each modality
ImageDataInputItem = Union[Image, str, ImageData, Dict]
AudioDataInputItem = Union[str, Dict]
VideoDataInputItem = Union[str, Dict]
# Union type for any multimodal data item
MultimodalDataInputItem = Union[
    ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
]
# Format types supporting single items, lists, or nested lists for batch processing
MultimodalDataInputFormat = Union[
    List[List[MultimodalDataInputItem]],
    List[MultimodalDataInputItem],
    MultimodalDataInputItem,
]
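
# Shape examples for the formats above (a minimal sketch; values are placeholders):
#   image_data="https://example.com/a.png"                  # one item
#   image_data=["a.png", "b.png"]                           # one item per request
#   image_data=[["a.png", "b.png"], ["c.png"]]              # multiple items per request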


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The sampling params. See `SamplingParams` for the available fields.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # Whether to return hidden states
    return_hidden_states: Union[List[bool], bool] = False

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # The path to the LoRA adapters
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # The uid of LoRA adapters, which should be initialized by the tokenizer manager
    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None
    bootstrap_pair_key: Optional[Union[List[str], str]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For background responses (OpenAI responses API)
    background: bool = False

    # Conversation id used for tracking requests
    conversation_id: Optional[str] = None

    # Label for the request
    label: Optional[str] = None

    # Priority for the request
    priority: Optional[int] = None

    # Image gen grpc migration
    return_bytes: bool = False

    # For customer metric labels
    customer_labels: Optional[Dict[str, str]] = None

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()
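
    # A minimal usage sketch (illustrative only): a single prompt with n=2
    # parallel samples is converted into a batch during normalization.
    #
    #   req = GenerateReqInput(text="Hello", sampling_params={"n": 2})
    #   req.normalize_batch_and_arguments()
    #   # Now req.is_single is False, req.batch_size == 1, and per-request
    #   # fields such as req.rid are lists of length batch_size * n == 2.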

    def _validate_inputs(self):
        """Validate that the input configuration is valid."""
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "One of text, input_ids, or input_embeds should be provided "
                "(received none, or all three)."
            )

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size."""
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed."""
        # Determine parallel sample count
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The value of n should be the same in all sampling_params of the batch."
                    )

        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
            if self.input_embeds is not None:
                self.input_embeds = [self.input_embeds]

    def _normalize_single_inputs(self):
        """Normalize inputs for a single example."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Calculate expanded batch size
        if self.parallel_sample_num == 1:
            num = self.batch_size
        else:
            # Expand parallel_sample_num
            num = self.batch_size * self.parallel_sample_num

        # Expand input based on type
        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_video_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)
        self._normalize_bootstrap_params(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num
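
    # Note on ordering (illustrative): list repetition is replica-major, so a
    # batch [a, b] with parallel_sample_num == 2 expands to [a, b, a, b];
    # sample j of request i therefore lands at index j * batch_size + i.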

    def _normalize_lora_paths(self, num):
        """Normalize LoRA paths for batch processing."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data for batch processing."""
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image, convert to list of single-image lists
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        elif isinstance(self.image_data, list):
            # Handle empty list case - treat as no images
            if len(self.image_data) == 0:
                self.image_data = [None] * num
                return

            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is
                for i in range(len(self.image_data)):
                    if self.image_data[i] is None or self.image_data[i] == [None]:
                        self.modalities.append(None)
                    elif len(self.image_data[i]) == 1:
                        self.modalities.append("image")
                    elif len(self.image_data[i]) > 1:
                        self.modalities.append("multi-images")
                    else:
                        # Ensure len(self.modalities) == len(self.image_data)
                        self.modalities.append(None)
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

    def _normalize_video_data(self, num):
        """Normalize video data for batch processing."""
        if self.video_data is None:
            self.video_data = [None] * num
        elif not isinstance(self.video_data, list):
            self.video_data = [self.video_data] * num
        elif isinstance(self.video_data, list):
            self.video_data = self.video_data * self.parallel_sample_num

    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        elif isinstance(self.audio_data, list):
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters for batch processing."""
        if self.sampling_params is None:
            self.sampling_params = [{}] * num
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            new_rids = [f"{self.rid}_{i}" for i in range(num)]
            self.rid = new_rids
        elif isinstance(self.rid, list):
            # Note: the length of rid shall be the same as the batch_size,
            # as the rid would be expanded for parallel sampling in tokenizer_manager
            if len(self.rid) != self.batch_size:
                raise ValueError(
                    "The length of the specified rids does not match the batch size."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing."""

        # Helper function to normalize a parameter
        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        # Normalize each logprob parameter
        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing."""
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _normalize_bootstrap_params(self, num):
        """Normalize bootstrap parameters for batch processing."""
        # Normalize bootstrap_host
        if self.bootstrap_host is None:
            self.bootstrap_host = [None] * num
        elif not isinstance(self.bootstrap_host, list):
            self.bootstrap_host = [self.bootstrap_host] * num
        elif isinstance(self.bootstrap_host, list):
            self.bootstrap_host = self.bootstrap_host * self.parallel_sample_num

        # Normalize bootstrap_port
        if self.bootstrap_port is None:
            self.bootstrap_port = [None] * num
        elif not isinstance(self.bootstrap_port, list):
            self.bootstrap_port = [self.bootstrap_port] * num
        elif isinstance(self.bootstrap_port, list):
            self.bootstrap_port = self.bootstrap_port * self.parallel_sample_num

        # Normalize bootstrap_room
        if self.bootstrap_room is None:
            self.bootstrap_room = [None] * num
        elif not isinstance(self.bootstrap_room, list):
            self.bootstrap_room = [self.bootstrap_room + i for i in range(num)]
        elif isinstance(self.bootstrap_room, list):
            self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num

        # Normalize bootstrap_pair_key
        if self.bootstrap_pair_key is None:
            self.bootstrap_pair_key = [None] * num
        elif not isinstance(self.bootstrap_pair_key, list):
            self.bootstrap_pair_key = [self.bootstrap_pair_key] * num
        elif isinstance(self.bootstrap_pair_key, list):
            self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num

    def _validate_session_params(self):
        """Validate that session parameters are properly formatted."""
        if self.session_params is not None:
            if not isinstance(self.session_params, dict) and not isinstance(
                self.session_params[0], dict
            ):
                raise ValueError("Session params must be a dict or a list of dicts.")

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            input_embeds=(
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
            video_data=self.video_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            modalities=self.modalities[i] if self.modalities else None,
            session_params=self.session_params,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            lora_id=self.lora_id[i] if self.lora_id is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            # If `__getitem__` is called, bootstrap_host, bootstrap_port, and
            # bootstrap_room must be lists.
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_port=(
                self.bootstrap_port[i] if self.bootstrap_port is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            bootstrap_pair_key=(
                self.bootstrap_pair_key[i]
                if self.bootstrap_pair_key is not None
                else None
            ),
            data_parallel_rank=self.data_parallel_rank,
            conversation_id=self.conversation_id,
            label=self.label,
            priority=self.priority,
            return_bytes=self.return_bytes,
        )
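
    # A minimal slicing sketch (illustrative only): after normalization, a
    # batched request can be split into per-request objects.
    #
    #   batch = GenerateReqInput(text=["a", "b"])
    #   batch.normalize_batch_and_arguments()
    #   singles = [batch[i] for i in range(batch.batch_size)]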


@dataclass
class TokenizedGenerateReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token ids to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool
    # Whether to return hidden states
    return_hidden_states: bool = False

    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # LoRA related
    lora_id: Optional[str] = None  # None means just use the base model

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None
    bootstrap_pair_key: Optional[str] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For dp balance
    dp_balance_id: int = -1

    # Label for the request
    label: Optional[str] = None

    # Priority for the request
    priority: Optional[int] = None

    # Image gen grpc migration
    return_bytes: bool = False

    # tracing context
    trace_context: Optional[Dict] = None


@dataclass
class BatchTokenizedGenerateReqInput:
    # The batch of tokenized requests
    batch: List[TokenizedGenerateReqInput]

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]

    def __iter__(self):
        return iter(self.batch)


@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a URL, or a base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False

    # For background responses (OpenAI responses API)
    background: bool = False

    # tracing context
    trace_context: Optional[Dict] = None

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                # Use independent dicts so the per-request mutation below is safe
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif isinstance(self.sampling_params, dict):
                self.sampling_params = [self.sampling_params] * self.batch_size
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0
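
        # Illustrative result: embedding requests never generate tokens, e.g.
        #   req = EmbeddingReqInput(text=["a", "b"])
        #   req.normalize_batch_and_arguments()
        #   # every entry of req.sampling_params now has {"max_new_tokens": 0}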

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def __getitem__(self, i):
        if self.is_cross_encoder_request:
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                is_cross_encoder_request=True,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            audio_data=self.audio_data[i] if self.audio_data is not None else None,
            video_data=self.video_data[i] if self.video_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )


@dataclass
class TokenizedEmbeddingReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The token type ids
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None
    # For dp balance
    dp_balance_id: int = -1


@dataclass
class BatchTokenizedEmbeddingReqInput:
    # The batch of tokenized embedding requests
    batch: List[TokenizedEmbeddingReqInput]

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]

    def __iter__(self):
        return iter(self.batch)


@dataclass
class BatchTokenIDOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

    # The information of placeholder tokens (e.g., image token)
    # idx is the index of the token in the prompt after expansion.
    # val is the length of padded tokens after expansion.
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]


@dataclass
class BatchMultimodalDecodeReq:
    decoded_ids: List[int]
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    read_offsets: List[int]
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    image_resolutions: List[List[int]]
    resize_image_resolutions: List[List[int]]

    # The request id
    rids: List[str]
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    # Placeholder token info
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    return_bytes: bool = False


@dataclass
class BatchStrOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]


@dataclass
class BatchMultimodalOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    decoded_ids: List[List[int]]
    # The outputs
    outputs: Union[List[str | bytes], List[List[Dict]]]

    # probability values for input tokens and output tokens
    input_token_logprobs_val: List[List[float]]
    input_token_logprobs_idx: List[List[int]]
    output_token_logprobs_val: List[List[float]]
    output_token_logprobs_idx: List[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    return_bytes: List[bool]


@dataclass
class BatchEmbeddingOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]
    # Placeholder token info
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]


@dataclass
class ClearHiCacheReqInput:
    pass


@dataclass
class ClearHiCacheReqOutput:
    success: bool


@dataclass
class FlushCacheReqInput:
    pass


@dataclass
class FlushCacheReqOutput:
    success: bool


@dataclass
class UpdateWeightFromDiskReqInput:
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None
    # Whether to update weights asynchronously
    is_async: bool = False
    # Whether to empty torch cache
    torch_empty_cache: bool = False
    # Whether to keep the scheduler paused after weight update
    keep_pause: bool = False


@dataclass
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
class UpdateWeightsFromDistributedReqInput:
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    # The group name
    group_name: str = "weight_update_group"
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    """Update model weights from tensor input.

    - Tensors are serialized for transmission
    - Data is structured in JSON for easy transmission over HTTP
    """

    serialized_named_tensors: List[Union[str, bytes]]
    # Optional format specification for loading
    load_format: Optional[str] = None
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None
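
    # A minimal construction sketch (illustrative only; assumes torch and
    # sglang.srt.utils.MultiprocessingSerializer are available, as used by the
    # tokenizer manager for this request type):
    #   named_tensors = [("layer.weight", torch.zeros(8, 8))]
    #   req = UpdateWeightsFromTensorReqInput(
    #       serialized_named_tensors=[MultiprocessingSerializer.serialize(named_tensors)]
    #   )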


@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqInput:
    # The master address
    master_address: str
    # The ports for each rank's communication group
    ports: str
    # The rank in the communication group
    group_rank: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_send_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput:
    success: bool
    message: str


@dataclass
class SendWeightsToRemoteInstanceReqInput:
    # The master address
    master_address: str
    # The ports for each rank's communication group
    ports: str
    # The group name
    group_name: str = "weight_send_group"


@dataclass
class SendWeightsToRemoteInstanceReqOutput:
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightVersionReqInput:
    # The new weight version
    new_version: str
    # Whether to abort all running requests before updating
    abort_all_requests: bool = True


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list


@dataclass
class ReleaseMemoryOccupationReqInput:
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ReleaseMemoryOccupationReqOutput:
    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ResumeMemoryOccupationReqOutput:
    pass


@dataclass
class SlowDownReqInput:
    forward_sleep_time: Optional[float]


@dataclass
class SlowDownReqOutput:
    pass


@dataclass
class AbortReq:
    # The request id
    rid: str = ""
    # Whether to abort all requests
    abort_all: bool = False
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None
    abort_reason: Optional[str] = None
    # Used in MultiTokenizerManager mode
    rids: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        self.rids = self.rid
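
    # Construction sketches (illustrative only):
    #   AbortReq(rid="req-123")    # abort a single request
    #   AbortReq(abort_all=True)   # abort every in-flight request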


@dataclass
class GetInternalStateReq:
    pass


@dataclass
class GetInternalStateReqOutput:
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    # The output directory
    output_dir: Optional[str] = None
    # If set, profiling runs for at most this many steps; it stops automatically
    # after that, so the caller does not need to call stop_profile.
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None


class ProfileReqType(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class ProfileReq:
    type: ProfileReqType
    output_dir: Optional[str] = None
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    profile_id: Optional[str] = None


@dataclass
class ProfileReqOutput:
    success: bool
    message: str


@dataclass
class FreezeGCReq:
    pass


@dataclass
class ConfigureLoggingReq:
    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None
    crash_dump_folder: Optional[str] = None


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    session_id: str


@dataclass
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool


@dataclass
class HealthCheckOutput:
    pass


class ExpertDistributionReq(Enum):
    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


@dataclass
class ExpertDistributionReqOutput:
    pass


@dataclass
class Function:
    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
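
    # A minimal construction sketch (illustrative only; the tool is a placeholder):
    #   req = ParseFunctionCallReq(
    #       text='{"name": "get_weather", "arguments": {"city": "Paris"}}',
    #       tools=[Tool(function=Function(name="get_weather"))],
    #       tool_call_parser="llama3",
    #   )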


@dataclass
class SeparateReasoningReqInput:
    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


@dataclass
class VertexGenerateReqInput:
    instances: List[dict]
    parameters: Optional[dict] = None


@dataclass
class RpcReqInput:
    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    success: bool
    message: str


@dataclass
class LoadLoRAAdapterReqInput:
    # The name of the LoRA module to be newly loaded.
    lora_name: str
    # The path to load the LoRA adapter from.
    lora_path: str
    # Whether to pin the LoRA adapter in memory.
    pinned: bool = False
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
            lora_path=self.lora_path,
            pinned=self.pinned,
        )
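
    # Usage sketch (illustrative): the TokenizerManager fills in lora_id and
    # converts the request into a registry entry via to_ref():
    #   req = LoadLoRAAdapterReqInput(lora_name="my-adapter", lora_path="/path/to/adapter")
    #   ref = req.to_ref()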


@dataclass
class UnloadLoRAAdapterReqInput:
    # The name of the LoRA module to unload.
    lora_name: str
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
        )


@dataclass
class LoRAUpdateResult:
    success: bool
    error_message: Optional[str] = None
    loaded_adapters: Optional[Dict[str, LoRARef]] = None


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult


@dataclass
class MultiTokenizerRegisterReq:
    rids: Optional[Union[List[str], str]] = None
    ipc_name: Optional[str] = None


@dataclass
class MultiTokenizerWrapper:
    worker_id: int
    obj: Optional[Any] = None


class BlockReqType(Enum):
    BLOCK = 1
    UNBLOCK = 2


@dataclass
class BlockReqInput:
    type: BlockReqType


@dataclass
class GetLoadReqInput:
    pass


@dataclass
class GetLoadReqOutput:
    dp_rank: int
    num_reqs: int
    num_waiting_reqs: int
    num_tokens: int


@dataclass
class WatchLoadUpdateReq:
    loads: List[GetLoadReqOutput]