io_struct.py 23.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
"""
The definition of objects transfered between different
16
processes (TokenizerManager, DetokenizerManager, Controller).
Lianmin Zheng's avatar
Lianmin Zheng committed
17
18
"""

19
import copy
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import uuid
YAMY's avatar
YAMY committed
21
from dataclasses import dataclass, field
22
from enum import Enum
23
from typing import Any, Dict, List, Optional, Union
24

25
from sglang.srt.managers.schedule_batch import BaseFinishReason
26
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
27
28


29
30
31
32
33
34
35
36
@dataclass
class SessionParams:
    """Parameters identifying a continual-prompting session and where to splice new input."""

    # The session id; None means no session is targeted.
    id: Optional[str] = None
    # The request id within the session to continue from.
    rid: Optional[str] = None
    # Offset into the previous request's output — presumably a token offset; confirm with Session usage.
    offset: Optional[int] = None
    # Whether to replace prior content from `offset` onward — NOTE(review): semantics defined by session code elsewhere; confirm.
    replace: Optional[bool] = None


Lianmin Zheng's avatar
Lianmin Zheng committed
37
38
@dataclass
class GenerateReqInput:
    """A generation request as received by the server, before tokenization.

    Every per-request field may hold a single value (one request) or a list
    (a batch). `normalize_batch_and_arguments` derives `is_single` and
    `batch_size`, and for batches expands each per-request field to a list of
    length `batch_size * parallel_sample_num`.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics).
    log_metrics: bool = True

    # The modalities of the image data [image, multi-images, video].
    modalities: Optional[List[str]] = None
    # LoRA related.
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting.
    session_params: Optional[Union[List[Dict], Dict]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # Whether to return hidden states.
    return_hidden_states: bool = False

    def normalize_batch_and_arguments(self):
        """Validate the input combination, then normalize batch shape and defaults in place.

        Raises:
            ValueError: if none (or all three) of text/input_ids/input_embeds
                are provided, or if input_ids is an empty list.
        """
        # NOTE: this only rejects "none provided" and "all three provided";
        # providing exactly two is tolerated and resolved by the priority
        # text > input_ids > input_embeds below.
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

        # Derive the batch size.
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if len(self.input_ids) == 0:
                raise ValueError("input_ids cannot be empty.")
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                # BUG FIX: is_single was left unset in this branch, causing an
                # AttributeError below for batched input_embeds.
                self.is_single = False
                self.batch_size = len(self.input_embeds)

        # Handle parallel sampling.
        # When parallel sampling is used, we always treat the input as a batch.
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            assert all(
                self.parallel_sample_num == sampling_params.get("n", 1)
                for sampling_params in self.sampling_params
            ), "The parallel_sample_num should be the same for all samples in sample params."

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]

        # Fill in default arguments.
        if self.is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = None
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
                # Expand parallel_sample_num.
                num = self.batch_size * self.parallel_sample_num

            if not self.image_data:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.audio_data is None:
                self.audio_data = [None] * num
            elif not isinstance(self.audio_data, list):
                self.audio_data = [self.audio_data] * num

            if self.sampling_params is None:
                # BUG FIX: `[{}] * num` would alias ONE shared dict across all
                # requests (a later mutation of one would affect all); build
                # independent dicts instead.
                self.sampling_params = [{} for _ in range(num)]
            elif not isinstance(self.sampling_params, list):
                # Broadcasting the user's single dict by reference is intentional.
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num
            else:
                # A user-provided list cannot be combined with parallel sampling:
                # the expansion from batch_size to num entries is not defined here.
                assert self.parallel_sample_num == 1

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num
            else:
                assert self.parallel_sample_num == 1

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
            else:
                assert self.parallel_sample_num == 1

            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = [None] * num
            elif not isinstance(self.token_ids_logprob, list):
                self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
            elif not isinstance(self.token_ids_logprob[0], list):
                # A flat list of token ids: give every request its own copy.
                self.token_ids_logprob = [
                    copy.deepcopy(self.token_ids_logprob) for _ in range(num)
                ]
            else:
                assert self.parallel_sample_num == 1

            if self.custom_logit_processor is None:
                self.custom_logit_processor = [None] * num
            elif not isinstance(self.custom_logit_processor, list):
                self.custom_logit_processor = [self.custom_logit_processor] * num
            else:
                assert self.parallel_sample_num == 1

        # Other checks.
        if self.session_params is not None:
            assert isinstance(self.session_params, dict) or isinstance(
                self.session_params[0], dict
            )

    def regenerate_rid(self):
        """Assign a fresh request id and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return the i-th single request of a normalized batch.

        Must be called only after `normalize_batch_and_arguments`.
        NOTE(review): input_embeds is not sliced/propagated here — confirm
        whether batched input_embeds are expected to reach this path.
        """
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=self.return_hidden_states,
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
263
264
265

@dataclass
class TokenizedGenerateReqInput:
    """A single generation request after tokenization."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The multimodal inputs
    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token id to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

Lianmin Zheng's avatar
Lianmin Zheng committed
303

304
305
306
307
@dataclass
class EmbeddingReqInput:
    """An embedding request as received by the server, before tokenization.

    Fields may hold a single value (one request) or a list (a batch);
    `normalize_batch_and_arguments` derives `is_single` / `batch_size` and
    fills in defaults in place.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    image_data: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility.
    # (annotation fixed: the default is None, so the type is Optional)
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics).
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video].
    modalities: Optional[List[str]] = None

    def normalize_batch_and_arguments(self):
        """Validate inputs, derive is_single/batch_size, and fill defaults in place.

        Raises:
            ValueError: if no input is given, or if both text and input_ids
                are given at the same time.
        """
        # At least one of text, input_ids, or image should be provided.
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time.
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size.
        self.batch_size = 0
        self.is_single = True

        # Check the batch size of text.
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
            else:
                self.batch_size += 1

        # Check the batch size of input_ids.
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
            else:
                self.batch_size += 1

        if self.batch_size > 1:
            self.is_single = False

        # Fill in default arguments.
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            # Embedding requests never generate new tokens.
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                # BUG FIX: `[{}] * batch_size` would alias ONE shared dict
                # across all requests; build independent dicts instead.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        """Assign a fresh request id and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return the i-th single request of a normalized batch."""
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
385
386
387


@dataclass
class TokenizedEmbeddingReqInput:
    """An embedding request after tokenization."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


Lianmin Zheng's avatar
Lianmin Zheng committed
401
402
@dataclass
class BatchTokenIDOut:
    """A batch of per-request token-id outputs, prior to detokenization."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Speculative-decoding verify count — presumably verify steps per request; TODO confirm
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

Liangsheng Yin's avatar
Liangsheng Yin committed
441

442
443
444
445
@dataclass
class BatchMultimodalDecodeReq:
    """A batch decode request for multimodal outputs."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
452
453


Lianmin Zheng's avatar
Lianmin Zheng committed
454
455
@dataclass
class BatchStrOut:
    """A batch of detokenized string outputs."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Speculative-decoding verify count — presumably verify steps per request; TODO confirm
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]

Liangsheng Yin's avatar
Liangsheng Yin committed
488

489
490
491
492
@dataclass
class BatchMultimodalOut:
    """A batch of multimodal outputs."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
502
503


504
505
@dataclass
class BatchEmbeddingOut:
    """A batch of embedding outputs."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]
515
516


Liangsheng Yin's avatar
Liangsheng Yin committed
517
518
519
@dataclass
class FlushCacheReq:
    """Empty marker message requesting a cache flush."""

    pass
Cody Yu's avatar
Cody Yu committed
520

521

522
@dataclass
class UpdateWeightFromDiskReqInput:
    """Request to reload model weights from disk."""

    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None


@dataclass
class UpdateWeightFromDiskReqOutput:
    """Result of a weight reload from disk."""

    # Whether the update succeeded
    success: bool
    # Human-readable status message
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0
536
537


538
539
540
541
542
543
544
545
546
547
548
549
550
@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request to update one named weight tensor via the distributed group."""

    # Parameter name
    name: str
    # Tensor dtype, as a string
    dtype: str
    # Tensor shape
    shape: List[int]


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Result of a distributed weight update."""

    success: bool
    message: str


551
552
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Request to load weights from in-memory serialized tensors."""

    # List containing one serialized Dict[str, torch.Tensor] per TP worker
    serialized_named_tensors: List[bytes]
    # The format to load the weights
    load_format: Optional[str]
    # Whether to flush the cache after the update
    flush_cache: bool
557
558
559
560
561
562
563
564


@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result of a weight update from tensors."""

    success: bool
    message: str


565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request to initialize the process group used for weight updates."""

    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Result of initializing the weight-update group."""

    success: bool
    message: str


587
588
589
590
591
592
593
594
595
596
597
@dataclass
class GetWeightsByNameReqInput:
    """Request to fetch a model parameter by name."""

    # Parameter name
    name: str
    # Presumably limits how many values of the parameter are returned — TODO confirm with handler
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    """The fetched parameter values."""

    parameter: list


598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
@dataclass
class ReleaseMemoryOccupationReqInput:
    """Empty marker message requesting memory occupation be released."""

    pass


@dataclass
class ReleaseMemoryOccupationReqOutput:
    """Empty marker reply to a release-memory request."""

    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    """Empty marker message requesting memory occupation be resumed."""

    pass


@dataclass
class ResumeMemoryOccupationReqOutput:
    """Empty marker reply to a resume-memory request."""

    pass


618
619
@dataclass
class AbortReq:
    """Request to abort an in-flight request."""

    # The request id
    rid: str
622
623


624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
@dataclass
class GetInternalStateReq:
    """Empty marker message requesting the server's internal state."""

    pass


@dataclass
class GetInternalStateReqOutput:
    """The server's internal state, as a free-form dict."""

    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    """Request to update server args in the internal state."""

    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    """Result of an internal-state update."""

    # Whether the state was updated
    updated: bool
    # The server args after the update
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    """User-facing request to start profiling."""

    # The output directory
    output_dir: Optional[str] = None
    # If set, it profile as many as this number of steps.
    # If it is set, profiling is automatically stopped after this step, and
    # the caller doesn't need to run stop_profile.
    num_steps: Optional[int] = None
    # Profiler activities to record — NOTE(review): accepted values defined by the profiler backend; confirm
    activities: Optional[List[str]] = None


class ProfileReqType(Enum):
    """Kind of profiling action requested."""

    START_PROFILE = 1
    STOP_PROFILE = 2
659
660


661
662
663
664
665
666
class ExpertDistributionReq(Enum):
    """Control messages for expert-distribution recording."""

    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


667
668
669
670
671
672
673
674
675
676
677
678
679
680
@dataclass
class ProfileReq:
    """Internal profiling control message."""

    # Whether to start or stop profiling
    type: ProfileReqType
    # The output directory
    output_dir: Optional[str] = None
    # If set, profiling stops automatically after this many steps
    num_steps: Optional[int] = None
    # Profiler activities to record — NOTE(review): accepted values defined by the profiler backend; confirm
    activities: Optional[List[str]] = None


@dataclass
class ProfileReqOutput:
    """Result of a profiling control message."""

    success: bool
    message: str


681
682
683
@dataclass
class ConfigureLoggingReq:
    """Request to reconfigure request logging at runtime."""

    # Whether to log requests
    log_requests: Optional[bool] = None
    # Verbosity level of request logging — TODO confirm accepted values with the handler
    log_requests_level: Optional[int] = None
    # Folder to dump requests into
    dump_requests_folder: Optional[str] = None
    # Threshold controlling when requests are dumped — NOTE(review): confirm semantics with the handler
    dump_requests_threshold: Optional[int] = None


689
690
691
@dataclass
class OpenSessionReqInput:
    """Request to open a session for continual prompting."""

    # Capacity related to the session's string length — NOTE(review): semantics defined by Session; confirm
    capacity_of_str_len: int
    # Optional client-chosen session id; presumably the server assigns one when None — TODO confirm
    session_id: Optional[str] = None
693
694
695
696
697
698
699
700
701


@dataclass
class CloseSessionReqInput:
    """Request to close an open session."""

    session_id: str


@dataclass
class OpenSessionReqOutput:
    """Result of opening a session."""

    # The session id; presumably None when opening failed — TODO confirm
    session_id: Optional[str]
    success: bool
YAMY's avatar
YAMY committed
704
705


706
707
708
709
710
@dataclass
class HealthCheckOutput:
    """Empty marker reply to a health check."""

    pass


YAMY's avatar
YAMY committed
711
712
713
714
715
716
717
718
719
720
721
722
723
724
@dataclass
class Function:
    """Function metadata for tool descriptions."""

    # Natural-language description of the function
    description: Optional[str] = None
    # Function name
    name: Optional[str] = None
    # Parameter schema — presumably a JSON-schema-like object; TODO confirm with the parser
    parameters: Optional[object] = None


@dataclass
class Tool:
    """A callable tool wrapping a Function."""

    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    """Request to parse function/tool calls out of generated text."""

    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
733
734


Xihuai Wang's avatar
Xihuai Wang committed
735
736
737
738
739
740
@dataclass
class SeparateReasoningReqInput:
    """Request to split reasoning content out of generated text."""

    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


741
742
743
744
@dataclass
class VertexGenerateReqInput:
    """Generate request in Google Vertex AI prediction format."""

    # One dict per prediction instance
    instances: List[dict]
    # Shared generation parameters
    parameters: Optional[dict] = None
745
746
747
748
749
750
751
752
753
754
755
756


@dataclass
class RpcReqInput:
    """Generic RPC call: a method name plus optional parameters."""

    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    """Generic RPC result."""

    success: bool
    message: str