# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams


@dataclass
class SessionParams:
    id: Optional[str] = None
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
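
# Illustrative session_params payload for continual prompting (field values are
# made-up examples): {"id": "sess-1", "rid": "req-0", "offset": 0, "replace": False}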


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be a file name, a URL, or a base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens into text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics).
    log_metrics: bool = True

    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # LoRA related
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    def normalize_batch_and_arguments(self):
        if (
            self.text is None and self.input_ids is None and self.input_embeds is None
        ) or (
            self.text is not None
            and self.input_ids is not None
            and self.input_embeds is not None
        ):
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

        # Handle parallel sampling
        # When parallel sampling is used, we always treat the input as a batch.
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list)
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            assert all(
                self.parallel_sample_num == sampling_params.get("n", 1)
                for sampling_params in self.sampling_params
            ), "The parallel_sample_num should be the same for all samples in sample params."

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]
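
        # Note (illustrative): text="Hello" with sampling_params={"n": 2} is
        # promoted here to a batch of size 1 and expanded below into
        # batch_size * parallel_sample_num = 2 request entries.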

        # Fill in default arguments
        if self.is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = None
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
                # Expand parallel_sample_num
                num = self.batch_size * self.parallel_sample_num

            if not self.image_data:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.sampling_params is None:
                self.sampling_params = [{}] * num
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num
            else:
                assert self.parallel_sample_num == 1

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num
            else:
                assert self.parallel_sample_num == 1

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
            else:
                assert self.parallel_sample_num == 1

            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = [None] * num
            elif not isinstance(self.token_ids_logprob, list):
                self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
            elif not isinstance(self.token_ids_logprob[0], list):
                # Deep-copy so each request gets an independent list of token ids.
                self.token_ids_logprob = [
                    copy.deepcopy(self.token_ids_logprob) for _ in range(num)
                ]
            else:
                assert self.parallel_sample_num == 1

            if self.custom_logit_processor is None:
                self.custom_logit_processor = [None] * num
            elif not isinstance(self.custom_logit_processor, list):
                self.custom_logit_processor = [self.custom_logit_processor] * num
            else:
                assert self.parallel_sample_num == 1

        # Other checks
        if self.session_params is not None:
            assert isinstance(self.session_params, dict) or isinstance(
                self.session_params[0], dict
            )

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
                else None
            ),
            return_hidden_states=self.return_hidden_states,
        )
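
# Minimal usage sketch (illustrative, not part of the API): a single prompt with
# parallel sampling is normalized into a batch of size 1 and then expanded.
#
#   req = GenerateReqInput(text="Hello", sampling_params={"n": 2})
#   req.normalize_batch_and_arguments()
#   assert req.is_single is False and req.text == ["Hello"] and len(req.rid) == 2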


@dataclass
class TokenizedGenerateReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token ids to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None

    # Custom logit processor for advanced sampling control. Must be a serialized instance
    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True

    def normalize_batch_and_arguments(self):
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
        else:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                # Use independent dicts so the per-request mutation below cannot alias.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
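
# Illustrative: EmbeddingReqInput(text=["a", "b"]).normalize_batch_and_arguments()
# yields batch_size == 2, two fresh rids, and max_new_tokens forced to 0.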


@dataclass
class TokenizedEmbeddingReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


@dataclass
class BatchTokenIDOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalDecodeReq:
    # The request id
    rids: List[str]
    finished_reasons: List[BaseFinishReason]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchStrOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]
    # The token ids
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The outputs
    outputs: List[List[Dict]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


@dataclass
class BatchEmbeddingOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]


@dataclass
class FlushCacheReq:
    pass


@dataclass
class UpdateWeightFromDiskReqInput:
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None


@dataclass
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
class UpdateWeightsFromDistributedReqInput:
    name: str
    dtype: str
    shape: List[int]


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    serialized_named_tensors: bytes  # a serialized Dict[str, torch.Tensor]
    load_format: Optional[str]
    flush_cache: bool
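
# Illustrative payload creation (pickle stands in for whatever serializer the
# server actually uses; this sketch does not assume a specific one):
#   import pickle, torch
#   payload = pickle.dumps({"lm_head.weight": torch.zeros(8)})
#   req = UpdateWeightsFromTensorReqInput(
#       serialized_named_tensors=payload, load_format=None, flush_cache=True
#   )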


@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"
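
# Illustrative construction (address and port are made-up values):
#   InitWeightsUpdateGroupReqInput(master_address="127.0.0.1", master_port=29500,
#                                  rank_offset=1, world_size=2)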


@dataclass
class InitWeightsUpdateGroupReqOutput:
    success: bool
    message: str


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list


@dataclass
class ReleaseMemoryOccupationReqInput:
    pass


@dataclass
class ReleaseMemoryOccupationReqOutput:
    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    pass


@dataclass
class ResumeMemoryOccupationReqOutput:
    pass


@dataclass
class AbortReq:
    # The request id
    rid: str


@dataclass
class GetInternalStateReq:
    pass


@dataclass
class GetInternalStateReqOutput:
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    # The output directory
    output_dir: Optional[str] = None
    # If set, profile at most this number of steps.
    # Profiling then stops automatically, so the caller doesn't need to run
    # stop_profile.
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
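
# Illustrative request (activity names are examples, not an exhaustive list):
#   ProfileReqInput(output_dir="/tmp/profile", num_steps=10, activities=["CPU", "GPU"])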


class ProfileReqType(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class ProfileReq:
    type: ProfileReqType
    output_dir: Optional[str] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None


@dataclass
class ProfileReqOutput:
    success: bool
    message: str


@dataclass
class ConfigureLoggingReq:
    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    session_id: str


@dataclass
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool


@dataclass
class HealthCheckOutput:
    pass


@dataclass
class Function:
    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    function: Function
    type: Optional[str] = "function"


@dataclass
class ParseFunctionCallReq:
    text: str  # The text to parse.
    tools: List[Tool] = field(
        default_factory=list
    )  # A list of available function tools (name, parameters, etc.).
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )
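
# Illustrative construction (tool name and text are made-up examples):
#   fn = Function(name="get_weather", parameters={"type": "object"})
#   req = ParseFunctionCallReq(text="...model output...", tools=[Tool(function=fn)])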


@dataclass
class SeparateReasoningReqInput:
    text: str  # The text to parse.
    reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".
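
# Illustrative use (the <think> tag format is an assumption for illustration):
#   SeparateReasoningReqInput(
#       text="<think>chain of thought</think>final answer",
#       reasoning_parser="deepseek-r1",
#   )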


@dataclass
class VertexGenerateReqInput:
    instances: List[dict]
    parameters: Optional[dict] = None