# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Scheduler).
"""

import copy
import uuid
from abc import ABC
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.utils import ImageData

# Handle serialization of Image for pydantic
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any


@dataclass
class BaseReq(ABC):
    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
    http_worker_ipc: Optional[str] = field(default=None, kw_only=True)

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        if isinstance(self.rid, list):
            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
        else:
            self.rid = uuid.uuid4().hex
        return self.rid


@dataclass
class BaseBatchReq(ABC):
    rids: Optional[List[str]] = field(default=None, kw_only=True)
    http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True)

    def regenerate_rids(self):
        """Generate new request IDs and return them."""
        self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
        return self.rids


# Parameters for a session
@dataclass
class SessionParams:
    id: Optional[str] = None
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
    drop_previous_output: Optional[bool] = None


# Type definitions for multimodal input data
# Individual data item types for each modality
ImageDataInputItem = Union[Image, str, ImageData, Dict]
AudioDataInputItem = Union[str, Dict]
VideoDataInputItem = Union[str, Dict]
# Union type for any multimodal data item
MultimodalDataInputItem = Union[
    ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
]
# Format types supporting single items, lists, or nested lists for batch processing
MultimodalDataInputFormat = Union[
    List[List[MultimodalDataInputItem]],
    List[MultimodalDataInputItem],
    MultimodalDataInputItem,
]
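
# Illustrative shapes for MultimodalDataInputFormat (examples, not executed):
#   single item for one request:      image_data="cat.png"
#   one item per request in a batch:  image_data=["cat.png", "dog.png"]
#   multiple items per request:       image_data=[["a.png", "b.png"], ["c.png"]]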


@dataclass
class GenerateReqInput(BaseReq):
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # Whether to return hidden states
    return_hidden_states: Union[List[bool], bool] = False

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # The path to the LoRA adapters
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # The uid of LoRA adapters; should be initialized by the tokenizer manager
    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
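    # A minimal sketch (assuming a hypothetical `MyLogitProcessor` subclass of
    # `CustomLogitProcessor`):
    #     req = GenerateReqInput(
    #         text="hello",
    #         custom_logit_processor=MyLogitProcessor().to_str(),
    #     )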

    # For disaggregated inference
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None
    bootstrap_pair_key: Optional[Union[List[str], str]] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # For background responses (OpenAI responses API)
    background: bool = False

    # Conversation id used for tracking requests
    conversation_id: Optional[str] = None

    # Priority for the request
    priority: Optional[int] = None

    # Extra key for classifying the request (e.g. cache_salt)
    extra_key: Optional[Union[List[str], str]] = None

    # Whether to disallow logging for this request (e.g. due to ZDR)
    no_logs: bool = False

    # For custom metric labels
    custom_labels: Optional[Dict[str, str]] = None

    # (Internal) Whether to return bytes for image generation
    return_bytes: bool = False

    # Whether to return entropy
    return_entropy: bool = False

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.

        This method resolves various input formats and ensures all parameters
        are properly formatted as either single values or batches depending on the input.
        It also handles parallel sampling expansion and sets default values for
        unspecified parameters.

        Raises:
            ValueError: If inputs are not properly specified (e.g., none or all of
                       text, input_ids, input_embeds are provided)
        """
        self._validate_inputs()
        self._determine_batch_size()
        self._handle_parallel_sampling()

        if self.is_single:
            self._normalize_single_inputs()
        else:
            self._normalize_batch_inputs()

    def _validate_inputs(self):
        """Validate that the input configuration is valid."""
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

    def _determine_batch_size(self):
        """Determine if this is a single example or a batch and the batch size."""
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

    def _handle_parallel_sampling(self):
        """Handle parallel sampling parameters and adjust batch size if needed."""
        # Determine parallel sample count
        if self.sampling_params is None:
            self.parallel_sample_num = 1
            return
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sampling_params in self.sampling_params:
                if self.parallel_sample_num != sampling_params.get("n", 1):
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )

        # If using parallel sampling with a single example, convert to batch
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
            if self.input_embeds is not None:
                self.input_embeds = [self.input_embeds]

    def _normalize_single_inputs(self):
        """Normalize inputs for a single example."""
        if self.sampling_params is None:
            self.sampling_params = {}
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.return_logprob is None:
            self.return_logprob = False
        if self.logprob_start_len is None:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None

    def _normalize_batch_inputs(self):
        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
        # Calculate expanded batch size
        if self.parallel_sample_num == 1:
            num = self.batch_size
        else:
            # Expand parallel_sample_num
            num = self.batch_size * self.parallel_sample_num

        # Expand input based on type
        self._expand_inputs(num)
        self._normalize_rid(num)
        self._normalize_lora_paths(num)
        self._normalize_image_data(num)
        self._normalize_video_data(num)
        self._normalize_audio_data(num)
        self._normalize_sampling_params(num)
        self._normalize_logprob_params(num)
        self._normalize_custom_logit_processor(num)
        self._normalize_bootstrap_params(num)

    def _expand_inputs(self, num):
        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
        if self.text is not None:
            if not isinstance(self.text, list):
                raise ValueError("Text should be a list for batch processing.")
            self.text = self.text * self.parallel_sample_num
        elif self.input_ids is not None:
            if not isinstance(self.input_ids, list) or not isinstance(
                self.input_ids[0], list
            ):
                raise ValueError(
                    "input_ids should be a list of lists for batch processing."
                )
            self.input_ids = self.input_ids * self.parallel_sample_num
        elif self.input_embeds is not None:
            if not isinstance(self.input_embeds, list):
                raise ValueError("input_embeds should be a list for batch processing.")
            self.input_embeds = self.input_embeds * self.parallel_sample_num

    def _normalize_lora_paths(self, num):
        """Normalize LoRA paths for batch processing."""
        if self.lora_path is not None:
            if isinstance(self.lora_path, str):
                self.lora_path = [self.lora_path] * num
            elif isinstance(self.lora_path, list):
                self.lora_path = self.lora_path * self.parallel_sample_num
            else:
                raise ValueError("lora_path should be a list or a string.")

    def _normalize_image_data(self, num):
        """Normalize image data for batch processing."""
        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            # Single image, convert to list of single-image lists
            self.image_data = [[self.image_data]] * num
            self.modalities = ["image"] * num
        elif isinstance(self.image_data, list):
            # Handle empty list case - treat as no images
            if len(self.image_data) == 0:
                self.image_data = [None] * num
                return

            if len(self.image_data) != self.batch_size:
                raise ValueError(
                    "The length of image_data should be equal to the batch size."
                )

            self.modalities = []
            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
                # Already a list of lists, keep as is
                for i in range(len(self.image_data)):
                    if self.image_data[i] is None or self.image_data[i] == [None]:
                        self.modalities.append(None)
                    elif len(self.image_data[i]) == 1:
                        self.modalities.append("image")
                    elif len(self.image_data[i]) > 1:
                        self.modalities.append("multi-images")
                    else:
                        # Ensure len(self.modalities) == len(self.image_data)
                        self.modalities.append(None)
                # Expand parallel_sample_num
                self.image_data = self.image_data * self.parallel_sample_num
                self.modalities = self.modalities * self.parallel_sample_num
            else:
                # List of images for a batch, wrap each in a list
                wrapped_images = [[img] for img in self.image_data]
                # Expand for parallel sampling
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

    def _normalize_video_data(self, num):
        """Normalize video data for batch processing."""
        if self.video_data is None:
            self.video_data = [None] * num
        elif not isinstance(self.video_data, list):
            self.video_data = [self.video_data] * num
        elif isinstance(self.video_data, list):
            self.video_data = self.video_data * self.parallel_sample_num

    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
            self.audio_data = [None] * num
        elif not isinstance(self.audio_data, list):
            self.audio_data = [self.audio_data] * num
        elif isinstance(self.audio_data, list):
            self.audio_data = self.audio_data * self.parallel_sample_num

    def _normalize_sampling_params(self, num):
        """Normalize sampling parameters for batch processing."""
        if self.sampling_params is None:
            self.sampling_params = [{}] * num
        elif isinstance(self.sampling_params, dict):
            self.sampling_params = [self.sampling_params] * num
        else:  # Already a list
            self.sampling_params = self.sampling_params * self.parallel_sample_num

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            new_rids = [f"{self.rid}_{i}" for i in range(num)]
            self.rid = new_rids
        elif isinstance(self.rid, list):
            # Note: the length of rid shall be the same as the batch_size,
            # as the rid would be expanded for parallel sampling in tokenizer_manager
            if len(self.rid) != self.batch_size:
                raise ValueError(
                    "The specified rids length mismatch with the batch_size for batch processing."
                )
        else:
            raise ValueError("The rid should be a string or a list of strings.")

    def _normalize_logprob_params(self, num):
        """Normalize logprob-related parameters for batch processing."""

        # Helper function to normalize a parameter
        def normalize_param(param, default_value, param_name):
            if param is None:
                return [default_value] * num
            elif not isinstance(param, list):
                return [param] * num
            else:
                if self.parallel_sample_num > 1:
                    raise ValueError(
                        f"Cannot use list {param_name} with parallel_sample_num > 1"
                    )
                return param

        # Normalize each logprob parameter
        self.return_logprob = normalize_param(
            self.return_logprob, False, "return_logprob"
        )
        self.logprob_start_len = normalize_param(
            self.logprob_start_len, -1, "logprob_start_len"
        )
        self.top_logprobs_num = normalize_param(
            self.top_logprobs_num, 0, "top_logprobs_num"
        )

        # Handle token_ids_logprob specially due to its nested structure
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = [None] * num
        elif not isinstance(self.token_ids_logprob, list):
            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
        elif not isinstance(self.token_ids_logprob[0], list):
            self.token_ids_logprob = [
                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
            ]
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
            )

    def _normalize_custom_logit_processor(self, num):
        """Normalize custom logit processor for batch processing."""
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
            self.custom_logit_processor = [self.custom_logit_processor] * num
        elif self.parallel_sample_num > 1:
            raise ValueError(
                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
            )

    def _normalize_bootstrap_params(self, num):
        """Normalize bootstrap parameters for batch processing."""
        # Normalize bootstrap_host
        if self.bootstrap_host is None:
            self.bootstrap_host = [None] * num
        elif not isinstance(self.bootstrap_host, list):
            self.bootstrap_host = [self.bootstrap_host] * num
        elif isinstance(self.bootstrap_host, list):
            self.bootstrap_host = self.bootstrap_host * self.parallel_sample_num

        # Normalize bootstrap_port
        if self.bootstrap_port is None:
            self.bootstrap_port = [None] * num
        elif not isinstance(self.bootstrap_port, list):
            self.bootstrap_port = [self.bootstrap_port] * num
        elif isinstance(self.bootstrap_port, list):
            self.bootstrap_port = self.bootstrap_port * self.parallel_sample_num

        # Normalize bootstrap_room
        if self.bootstrap_room is None:
            self.bootstrap_room = [None] * num
        elif not isinstance(self.bootstrap_room, list):
            self.bootstrap_room = [self.bootstrap_room + i for i in range(num)]
        elif isinstance(self.bootstrap_room, list):
            self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num
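        # Note (illustrative): a scalar bootstrap_room R is expanded to
        # [R, R + 1, ..., R + num - 1], giving each expanded request a
        # distinct room id.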

        # Normalize bootstrap_pair_key
        if self.bootstrap_pair_key is None:
            self.bootstrap_pair_key = [None] * num
        elif not isinstance(self.bootstrap_pair_key, list):
            self.bootstrap_pair_key = [self.bootstrap_pair_key] * num
        elif isinstance(self.bootstrap_pair_key, list):
            self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num

    def _validate_session_params(self):
        """Validate that session parameters are properly formatted."""
        if self.session_params is not None:
            if not isinstance(self.session_params, dict) and not isinstance(
                self.session_params[0], dict
            ):
                raise ValueError("Session params must be a dict or a list of dicts.")

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            input_embeds=(
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
            video_data=self.video_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            ),
            modalities=self.modalities[i] if self.modalities else None,
            session_params=self.session_params,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            lora_id=self.lora_id[i] if self.lora_id is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            # if `__getitem__` is called, bootstrap_host, bootstrap_port, and bootstrap_room must be lists
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_port=(
                self.bootstrap_port[i] if self.bootstrap_port is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
            bootstrap_pair_key=(
                self.bootstrap_pair_key[i]
                if self.bootstrap_pair_key is not None
                else None
            ),
            data_parallel_rank=self.data_parallel_rank,
            conversation_id=self.conversation_id,
            priority=self.priority,
            extra_key=self.extra_key,
            no_logs=self.no_logs,
            custom_labels=self.custom_labels,
            return_bytes=self.return_bytes,
            return_entropy=self.return_entropy,
            http_worker_ipc=self.http_worker_ipc,
        )
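
# Illustrative usage (a sketch, not executed): after
# normalize_batch_and_arguments() runs on a batched GenerateReqInput,
# indexing such as `req[0]` yields the per-request view constructed above.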


@dataclass
class TokenizedGenerateReqInput(BaseReq):
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token id to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # Whether to return hidden states
    return_hidden_states: bool = False

    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # LoRA related
    lora_id: Optional[str] = None  # None means just use the base model

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None
    bootstrap_room: Optional[int] = None
    bootstrap_pair_key: Optional[str] = None

    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None

    # Priority for the request
    priority: Optional[int] = None

    # Extra key for classifying the request (e.g. cache_salt)
    extra_key: Optional[str] = None

    # Whether to disallow logging for this request (e.g. due to ZDR)
    no_logs: bool = False

    # tracing context
    trace_context: Optional[Dict] = None

    # (Internal) Whether to return bytes for image generation
    return_bytes: bool = False

    # Whether to return entropy
    return_entropy: bool = False


@dataclass
class BatchTokenizedGenerateReqInput(BaseBatchReq):
    # The batch of tokenized requests
    batch: List[TokenizedGenerateReqInput]

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]

    def __iter__(self):
        return iter(self.batch)


@dataclass
class EmbeddingReqInput(BaseReq):
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[List[str]], List[str], str]] = None
    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
    # Can be formatted as:
    # - Single image for a single request
    # - List of images (one per request in a batch)
    # - List of lists of images (multiple images per request)
    # See also python/sglang/srt/utils.py:load_image for more details.
    image_data: Optional[MultimodalDataInputFormat] = None
    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
    video_data: Optional[MultimodalDataInputFormat] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # For cross-encoder requests
    is_cross_encoder_request: bool = False
    # Priority for the request
    priority: Optional[int] = None

    # For background responses (OpenAI responses API)
    background: bool = False

    # tracing context
    trace_context: Optional[Dict] = None

    # The number of dimensions the resulting output embeddings should have. Applicable to Matryoshka embeddings.
    dimensions: Optional[int] = None

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
                self.is_single = False
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
                self.is_single = False
            else:
                self.batch_size += 1

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                self.sampling_params = [{}] * self.batch_size
            elif isinstance(self.sampling_params, dict):
                self.sampling_params = [self.sampling_params] * self.batch_size
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0
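
        # Note: embedding requests decode no tokens, so max_new_tokens is
        # pinned to 0 on every request.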

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def __getitem__(self, i):
        if self.is_cross_encoder_request:
            return EmbeddingReqInput(
                text=[self.text[i]] if self.text is not None else None,
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                is_cross_encoder_request=True,
                http_worker_ipc=self.http_worker_ipc,
            )

        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            audio_data=self.audio_data[i] if self.audio_data is not None else None,
            video_data=self.video_data[i] if self.video_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            dimensions=self.dimensions,
            http_worker_ipc=self.http_worker_ipc,
        )


@dataclass
class TokenizedEmbeddingReqInput(BaseReq):
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The token type ids
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None
    # Priority for the request
    priority: Optional[int] = None
    # The number of dimensions the resulting output embeddings should have. Applicable to Matryoshka embeddings.
    dimensions: Optional[int] = None


@dataclass
class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
    # The batch of tokenized embedding requests
    batch: List[TokenizedEmbeddingReqInput]

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]

    def __iter__(self):
        return iter(self.batch)


@dataclass
class BatchTokenIDOutput(BaseBatchReq):
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]
    spec_accepted_tokens: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]
    output_token_entropy_val: List[float]

    # Hidden states
    output_hidden_states: List[List[float]]

    # The information of placeholder tokens (e.g., image token)
    # idx is the index of the token in the prompt after expansion.
    # val is the length of padded tokens after expansion.
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    # Number of times each request was retracted.
    retraction_counts: List[int]

    # The trainer step id. Used to know which step's weights are used for sampling.
    token_steps: Optional[List[List[int]]] = None


@dataclass
class BatchMultimodalDecodeReq(BaseBatchReq):
    decoded_ids: List[int]
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    read_offsets: List[int]
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    image_resolutions: List[List[int]]
    resize_image_resolutions: List[List[int]]

    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    # The information of placeholder tokens (e.g., image token)
    # idx is the index of the token in the prompt after expansion.
    # val is the length of padded tokens after expansion.
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    return_bytes: List[bool]

    # The trainer step id. Used to know which step's weights are used for sampling.
    token_steps: Optional[List[List[int]]] = None


@dataclass
class BatchStrOutput(BaseBatchReq):
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]
    spec_accepted_tokens: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]
    output_token_entropy_val: List[float]

    # Hidden states
    output_hidden_states: List[List[float]]

    # The information of placeholder tokens (e.g., image token)
    # idx is the index of the token in the prompt after expansion.
    # val is the length of padded tokens after expansion.
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    # Number of times each request was retracted.
    retraction_counts: List[int]

    # The trainer step id. Used to know which step's weights are used for sampling.
    token_steps: Optional[List[List[int]]] = None


@dataclass
class BatchMultimodalOutput(BaseBatchReq):
    # The finish reason
    finished_reasons: List[dict]
    decoded_ids: List[List[int]]
    # The outputs
    outputs: Union[List[str | bytes], List[List[Dict]]]

    # probability values for input tokens and output tokens
    input_token_logprobs_val: List[List[float]]
    input_token_logprobs_idx: List[List[int]]
    output_token_logprobs_val: List[List[float]]
    output_token_logprobs_idx: List[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    return_bytes: List[bool]


@dataclass
class BatchEmbeddingOutput(BaseBatchReq):
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embeddings
    embeddings: Union[List[List[float]], List[Dict[int, float]]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]
    # Placeholder token info
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]

    # Number of times each request was retracted.
    retraction_counts: List[int]


@dataclass
class ClearHiCacheReqInput(BaseReq):
    pass


@dataclass
class ClearHiCacheReqOutput(BaseReq):
    success: bool


@dataclass
class FlushCacheReqInput(BaseReq):
    pass


@dataclass
class FlushCacheReqOutput(BaseReq):
    success: bool


@dataclass
class UpdateWeightFromDiskReqInput(BaseReq):
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None
    # Whether to update weights asynchronously
    is_async: bool = False
    # Whether to empty torch cache
    torch_empty_cache: bool = False
    # Whether to keep the scheduler paused after weight update
    keep_pause: bool = False
    # Whether to recapture the CUDA graph after the weight update
    recapture_cuda_graph: bool = False
    # The trainer step id. Used to know which step's weights are used for sampling.
    token_step: int = 0


@dataclass
class UpdateWeightFromDiskReqOutput(BaseReq):
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
class UpdateWeightsFromDistributedReqInput(BaseReq):
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    # The group name
    group_name: str = "weight_update_group"
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None


@dataclass
class UpdateWeightsFromDistributedReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput(BaseReq):
    """Update model weights from tensor input.

    - Tensors are serialized for transmission
    - Data is structured in JSON for easy transmission over HTTP
    """

    serialized_named_tensors: List[Union[str, bytes]]
    # Optional format specification for loading
    load_format: Optional[str] = None
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
    # Whether to abort all requests before updating weights
    abort_all_requests: bool = False
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None


@dataclass
class UpdateWeightsFromTensorReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
    # The master address
    master_address: str
    # The ports for each rank's communication group
    ports: str
    # The rank in the communication group
    group_rank: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_send_group"
    # The backend
    backend: str = "nccl"


# Currently, UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput
# are only used by Checkpoint Engine (https://github.com/MoonshotAI/checkpoint-engine)
@dataclass
class UpdateWeightsFromIPCReqInput(BaseReq):
    # ZMQ socket paths for each device UUID
    zmq_handles: Dict[str, str]
    # Whether to flush cache after weight update
    flush_cache: bool = True
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None


@dataclass
class UpdateWeightsFromIPCReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class SendWeightsToRemoteInstanceReqInput(BaseReq):
    # The master address
    master_address: str
    # The ports for each rank's communication group
    ports: str
    # The group name
    group_name: str = "weight_send_group"


@dataclass
class SendWeightsToRemoteInstanceReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput(BaseReq):
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class DestroyWeightsUpdateGroupReqInput(BaseReq):
    group_name: str = "weight_update_group"


@dataclass
class DestroyWeightsUpdateGroupReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class UpdateWeightVersionReqInput(BaseReq):
    # The new weight version
    new_version: str
    # Whether to abort all running requests before updating
    abort_all_requests: bool = True


@dataclass
class GetWeightsByNameReqInput(BaseReq):
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput(BaseReq):
    parameter: list


@dataclass
class ReleaseMemoryOccupationReqInput(BaseReq):
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ReleaseMemoryOccupationReqOutput(BaseReq):
    pass


@dataclass
class ResumeMemoryOccupationReqInput(BaseReq):
    # Optional tags to identify the memory region, which is primarily used for RL
    # Currently we only support `weights` and `kv_cache`
    tags: Optional[List[str]] = None


@dataclass
class ResumeMemoryOccupationReqOutput(BaseReq):
    pass


@dataclass
class SlowDownReqInput(BaseReq):
    forward_sleep_time: Optional[float]


@dataclass
class SlowDownReqOutput(BaseReq):
    pass


@dataclass
class AbortReq(BaseReq):
    # Whether to abort all requests
    abort_all: bool = False
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None
    abort_message: Optional[str] = None

    def __post_init__(self):
        # FIXME: This is a hack to keep the same with the old code
        if self.rid is None:
            self.rid = ""


@dataclass
class GetInternalStateReq(BaseReq):
    pass


@dataclass
class GetInternalStateReqOutput(BaseReq):
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq(BaseReq):
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput(BaseReq):
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput(BaseReq):
    # The output directory
    output_dir: Optional[str] = None
    # If set, profiling runs for at most this many steps and is then stopped
    # automatically, so the caller doesn't need to run stop_profile.
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    # Merge profiles from all ranks into a single trace
    merge_profiles: bool = False


class ProfileReqType(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class ProfileReq(BaseReq):
    type: ProfileReqType
    output_dir: Optional[str] = None
    start_step: Optional[int] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
    profile_by_stage: bool = False
    with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
    profile_id: Optional[str] = None
    # Merge profiles from all ranks into a single trace
    merge_profiles: bool = False


@dataclass
class ProfileReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class FreezeGCReq(BaseReq):
    pass


@dataclass
class ConfigureLoggingReq(BaseReq):
    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None
    crash_dump_folder: Optional[str] = None


@dataclass
class OpenSessionReqInput(BaseReq):
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput(BaseReq):
    session_id: str


@dataclass
class OpenSessionReqOutput(BaseReq):
    session_id: Optional[str]
    success: bool


@dataclass
class HealthCheckOutput(BaseReq):
    pass


class ExpertDistributionReqType(Enum):
    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


@dataclass
class ExpertDistributionReq(BaseReq):
    action: ExpertDistributionReqType


@dataclass
class ExpertDistributionReqOutput(BaseReq):
    pass


@dataclass
class Function:
    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq(BaseReq):
    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
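
# Illustrative usage (a sketch, not executed):
#     req = ParseFunctionCallReq(
#         text='{"name": "get_weather", "parameters": {"city": "Paris"}}',
#         tools=[Tool(function=Function(name="get_weather"))],
#         tool_call_parser="qwen25",
#     )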


@dataclass
class SeparateReasoningReqInput(BaseReq):
    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


@dataclass
class VertexGenerateReqInput(BaseReq):
    instances: List[dict]
    parameters: Optional[dict] = None


@dataclass
class RpcReqInput(BaseReq):
    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput(BaseReq):
    success: bool
    message: str


@dataclass
class LoadLoRAAdapterReqInput(BaseReq):
    # The name of the LoRA adapter to be newly loaded.
    lora_name: str
    # The path to load the LoRA adapter from.
    lora_path: str
    # Whether to pin the LoRA adapter in memory.
    pinned: bool = False
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
            lora_path=self.lora_path,
            pinned=self.pinned,
        )


@dataclass
class UnloadLoRAAdapterReqInput(BaseReq):
    # The name of the LoRA adapter to unload.
    lora_name: str
    # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None

    def to_ref(self) -> LoRARef:
        return LoRARef(
            lora_id=self.lora_id,
            lora_name=self.lora_name,
        )


@dataclass
class LoRAUpdateOutput(BaseReq):
    success: bool
    error_message: Optional[str] = None
    loaded_adapters: Optional[Dict[str, LoRARef]] = None


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
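
# Note: both the load and unload paths reuse LoRAUpdateOutput, so callers can
# check `success`, `error_message`, and `loaded_adapters` uniformly.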


class BlockReqType(Enum):
    BLOCK = 1
    UNBLOCK = 2


@dataclass
class BlockReqInput(BaseReq):
    type: BlockReqType


@dataclass
class GetLoadReqInput(BaseReq):
    pass


@dataclass
class GetLoadReqOutput(BaseReq):
    dp_rank: int
    num_reqs: int
    num_waiting_reqs: int
    num_tokens: int


@dataclass
class WatchLoadUpdateReq(BaseReq):
    loads: List[GetLoadReqOutput]


@dataclass
class SetInjectDumpMetadataReqInput(BaseReq):
    dump_metadata: Dict[str, Any]


@dataclass
class SetInjectDumpMetadataReqOutput(BaseReq):
    success: bool


@dataclass
class LazyDumpTensorsReqInput(BaseReq):
    pass


@dataclass
class LazyDumpTensorsReqOutput(BaseReq):
    success: bool


def _check_all_req_types():
    """A helper function to check all request types are defined in this file."""
    import inspect
    import sys

    all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
    for class_type in all_classes:
        # check its name
        name = class_type[0]
        is_io_struct = (
            name.endswith("Req") or name.endswith("Input") or name.endswith("Output")
        )
        is_base_req = issubclass(class_type[1], BaseReq) or issubclass(
            class_type[1], BaseBatchReq
        )
        if is_io_struct and not is_base_req:
            raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
        if is_base_req and not is_io_struct:
            raise ValueError(
                f"{name} is a subclass of BaseReq but not follow the naming convention."
            )
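
# Illustrative: a dataclass named `FooReqInput(BaseReq)` passes the check above,
# while a `FooReqInput` that does not derive from BaseReq or BaseBatchReq raises.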


_check_all_req_types()