# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transfered between different
16
processes (TokenizerManager, DetokenizerManager, Controller).
Lianmin Zheng's avatar
Lianmin Zheng committed
17
18
"""

import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams


@dataclass
class SessionParams:
    id: Optional[str] = None
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be a file name, a URL, or a base64-encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling parameters. See `SamplingParams` for the accepted fields.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If returning logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is -1, which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If returning logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If returning logprobs, the token IDs to return logprobs for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens into text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # LoRA related
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None
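    # An illustrative value (keys mirror `SessionParams` above; shown as an
    # example, not a wire-format spec):
    #   {"id": "my-session", "rid": None, "offset": 0, "replace": False}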

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
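    # Illustrative flow (the processor class name is hypothetical): on the
    # client side, serialized = MyLogitProcessor().to_str(), then pass
    # custom_logit_processor=serialized here.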

    # Whether to return hidden states
    return_hidden_states: bool = False

    def normalize_batch_and_arguments(self):
        # Exactly one of text, input_ids, or input_embeds should be provided.
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            sum(
                x is not None
                for x in (self.text, self.input_ids, self.input_embeds)
            )
            > 1
        ):
            raise ValueError(
                "Exactly one of text, input_ids, or input_embeds should be provided."
            )

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

        # Handle parallel sampling
        # When parallel sampling is used, we always treat the input as a batch.
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            assert all(
                self.parallel_sample_num == sampling_params.get("n", 1)
                for sampling_params in self.sampling_params
            ), "The parallel_sample_num should be the same for all samples in the sampling parameters."

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]

        # Fill in default arguments
        if self.is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = None
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
                # Expand parallel_sample_num
                num = self.batch_size * self.parallel_sample_num

            if not self.image_data:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(num)]
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num
            else:
                assert self.parallel_sample_num == 1

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num
            else:
                assert self.parallel_sample_num == 1

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
            else:
                assert self.parallel_sample_num == 1

            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = [None] * num
            elif not isinstance(self.token_ids_logprob, list):
                self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
            elif not isinstance(self.token_ids_logprob[0], list):
                self.token_ids_logprob = [
                    copy.deepcopy(self.token_ids_logprob) for _ in range(num)
                ]
            else:
                assert self.parallel_sample_num == 1

            if self.custom_logit_processor is None:
                self.custom_logit_processor = [None] * num
            elif not isinstance(self.custom_logit_processor, list):
                self.custom_logit_processor = [self.custom_logit_processor] * num
            else:
                assert self.parallel_sample_num == 1

        # Other checks
        if self.session_params is not None:
            assert isinstance(self.session_params, dict) or isinstance(
                self.session_params[0], dict
            )
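
    # A sketch of the normalization above with illustrative values: a single
    # prompt with parallel sampling n=2 is promoted to a batch, and per-request
    # fields such as `rid` are expanded to batch_size * parallel_sample_num
    # entries.
    #
    #   req = GenerateReqInput(text="Hello", sampling_params={"n": 2})
    #   req.normalize_batch_and_arguments()
    #   # req.is_single == False, req.batch_size == 1,
    #   # req.parallel_sample_num == 2, len(req.rid) == 2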

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=self.return_hidden_states,
        )
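
    # Illustrative: after normalization, indexing splits a batch into
    # standalone per-request inputs, e.g. `first_req = batched_req[0]`.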


@dataclass
class TokenizedGenerateReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If returning logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If returning logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If returning logprobs, the token IDs to return logprobs for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False


@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The image input. It can be a file name, a URL, or a base64-encoded string.
    image_data: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
            else:
                self.batch_size += 1

        if self.batch_size > 1:
            self.is_single = False

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0
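
    # Illustrative (based on the normalization above):
    #   req = EmbeddingReqInput(text=["a", "b"])
    #   req.normalize_batch_and_arguments()
    #   # req.batch_size == 2, req.is_single == False, and every
    #   # sampling_params entry becomes {"max_new_tokens": 0}.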

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )


@dataclass
class TokenizedEmbeddingReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


@dataclass
class BatchTokenIDOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalDecodeReq:
    # The request id
    rids: List[str]
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchStrOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchEmbeddingOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class FlushCacheReq:
    pass


@dataclass
class UpdateWeightFromDiskReqInput:
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None


@dataclass
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
class UpdateWeightsFromDistributedReqInput:
    name: str
    dtype: str
    shape: List[int]


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    serialized_named_tensors: bytes  # a serialized Dict[str, torch.Tensor]
    load_format: Optional[str]
    flush_cache: bool
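
    # A hedged construction sketch (plain pickle is shown only for
    # illustration; the actual serialization scheme is agreed on by the
    # sender and the receiver outside this file):
    #
    #   import pickle, torch
    #   payload = pickle.dumps({"lm_head.weight": torch.zeros(4, 4)})
    #   req = UpdateWeightsFromTensorReqInput(
    #       serialized_named_tensors=payload, load_format=None, flush_cache=True
    #   )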


@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    success: bool
    message: str
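
# An illustrative request for forming the weight-update group (all values are
# placeholders):
#
#   InitWeightsUpdateGroupReqInput(
#       master_address="10.0.0.1", master_port=29500, rank_offset=1, world_size=2
#   )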


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list


@dataclass
class ReleaseMemoryOccupationReqInput:
    pass


@dataclass
class ReleaseMemoryOccupationReqOutput:
    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    pass


@dataclass
class ResumeMemoryOccupationReqOutput:
    pass


@dataclass
class AbortReq:
    # The request id
    rid: str


@dataclass
class GetInternalStateReq:
    pass


@dataclass
class GetInternalStateReqOutput:
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    # The output directory
    output_dir: Optional[str] = None
    # If set, profiling runs for at most this number of steps.
    # Profiling stops automatically after that, so the caller
    # does not need to call stop_profile.
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None


class ProfileReqType(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class ProfileReq:
    type: ProfileReqType
    output_dir: Optional[str] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None


@dataclass
class ProfileReqOutput:
    success: bool
    message: str


@dataclass
class ConfigureLoggingReq:
    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    session_id: str


@dataclass
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool


@dataclass
class HealthCheckOutput:
    pass


@dataclass
class Function:
    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
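
# Illustrative usage (the tool schema and the JSON text are hypothetical; the
# actual text format depends on the model):
#
#   ParseFunctionCallReq(
#       text='{"name": "get_weather", "parameters": {"city": "Paris"}}',
#       tools=[Tool(function=Function(name="get_weather", parameters={"type": "object"}))],
#       tool_call_parser="llama3",
#   )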


@dataclass
class SeparateReasoningReqInput:
    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".
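
# Illustrative: text="<think>2 + 2 = 4</think>The answer is 4." with
# reasoning_parser="deepseek-r1" separates the reasoning span from the final
# answer (the exact delimiter tags are model-specific and assumed here).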


@dataclass
class VertexGenerateReqInput:
    instances: List[dict]
    parameters: Optional[dict] = None
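
# An illustrative payload following the standard Vertex AI prediction request
# shape ({"instances": [...], "parameters": {...}}); the instance contents are
# placeholders:
#
#   VertexGenerateReqInput(instances=[{"text": "Hello"}], parameters=None)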


@dataclass
class RpcReqInput:
    method: str
    parameters: Optional[Dict] = None


@dataclass
class RpcReqOutput:
    success: bool
    message: str