sglang / Commits / 3c699772

Commit 3c699772 (Unverified)
Authored Oct 03, 2025 by Liangsheng Yin; committed by GitHub on Oct 03, 2025
Parent: e8100774

Introduce naming convention in `io_struct` and base sglang io classes. (#10133)
Showing 10 changed files with 223 additions and 189 deletions (+223 −189).
python/sglang/srt/entrypoints/grpc_request_manager.py            +7   −7
python/sglang/srt/entrypoints/http_server.py                     +3   −5
python/sglang/srt/lora/lora_manager.py                           +5   −5
python/sglang/srt/managers/detokenizer_manager.py                +10  −10
python/sglang/srt/managers/io_struct.py                          +126 −102
python/sglang/srt/managers/multi_tokenizer_mixin.py              +17  −17
python/sglang/srt/managers/scheduler.py                          +11  −8
python/sglang/srt/managers/scheduler_output_processor_mixin.py   +11  −7
python/sglang/srt/managers/tokenizer_communicator_mixin.py       +9   −5
python/sglang/srt/managers/tokenizer_manager.py                  +24  −23
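In brief: every inter-process message class in `io_struct.py` now derives from a common base, `BaseReq` for single requests or `BaseBatchReq` for batches, and a naming convention is enforced at import time (class names must end in `Req`, `Input`, or `Output`). As part of the rename, `BatchTokenIDOut`, `BatchStrOut`, `BatchEmbeddingOut`, and `BatchMultimodalOut` become `Batch*Output`, `LoRAUpdateResult` becomes `LoRAUpdateOutput`, and `MultiTokenizerManager` becomes `TokenizerWorker`. A condensed sketch of the resulting class shape (field sets abbreviated; the `GenerateReqInput` and `BatchStrOutput` bodies here are illustrative reductions of the real classes):

    from abc import ABC
    from dataclasses import dataclass, field
    from typing import List, Optional, Union

    @dataclass
    class BaseReq(ABC):
        # Shared, keyword-only request id (kw_only requires Python 3.10+).
        rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)

    @dataclass
    class BaseBatchReq(ABC):
        rids: Optional[List[str]] = field(default=None, kw_only=True)

    @dataclass
    class GenerateReqInput(BaseReq):      # was: a plain dataclass with its own rid field
        text: Optional[str] = None        # abbreviated

    @dataclass
    class BatchStrOutput(BaseBatchReq):   # was: BatchStrOut with its own rids field
        output_strs: List[str] = field(default_factory=list)  # abbreviated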
python/sglang/srt/entrypoints/grpc_request_manager.py

@@ -22,8 +22,8 @@ import zmq.asyncio
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
     HealthCheckOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -467,9 +467,9 @@ class GrpcRequestManager:
                     await self.is_pause_cond.wait()

                 # Handle different output types
-                if isinstance(recv_obj, BatchTokenIDOut):
+                if isinstance(recv_obj, BatchTokenIDOutput):
                     await self._handle_batch_output(recv_obj)
-                elif isinstance(recv_obj, BatchEmbeddingOut):
+                elif isinstance(recv_obj, BatchEmbeddingOutput):
                     await self._handle_embedding_output(recv_obj)
                 elif isinstance(recv_obj, HealthCheckOutput):
                     await self._handle_health_check_output(recv_obj)
@@ -498,7 +498,7 @@ class GrpcRequestManager:
     def _convert_logprob_style(
         self,
         state: GrpcReqState,
-        batch_out: BatchTokenIDOut,
+        batch_out: BatchTokenIDOutput,
         batch_index: int,
     ):
         """
@@ -545,7 +545,7 @@ class GrpcRequestManager:
                 batch_out.output_top_logprobs_idx[batch_index]
             )

-    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
+    async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
         """Handle batch generation output from scheduler."""
         # Process each request in the batch
         for i, rid in enumerate(batch_out.rids):
@@ -666,7 +666,7 @@ class GrpcRequestManager:
         asyncio.create_task(cleanup())

-    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
+    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
         """Handle batch embedding output from scheduler."""
         for i, rid in enumerate(batch_out.rids):
             if rid not in self.rid_to_state:
python/sglang/srt/entrypoints/http_server.py

@@ -94,8 +94,8 @@ from sglang.srt.managers.io_struct import (
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
-    MultiTokenizerManager,
     MultiTokenizerRouter,
+    TokenizerWorker,
     get_main_process_id,
     monkey_patch_uvicorn_multiprocessing,
     read_from_shared_memory,
@@ -127,9 +127,7 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
-    tokenizer_manager: Union[
-        TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
-    ]
+    tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker]
     template_manager: TemplateManager
     scheduler_info: Dict
@@ -164,7 +162,7 @@ async def init_multi_tokenizer() -> ServerArgs:
     )

     # Launch multi-tokenizer manager process
-    tokenizer_manager = MultiTokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerWorker(server_args, port_args)
     template_manager = TemplateManager()
     template_manager.initialize_templates(
         tokenizer_manager=tokenizer_manager,
python/sglang/srt/lora/lora_manager.py

@@ -35,7 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_target_modules,
     get_target_module_name,
 )
-from sglang.srt.managers.io_struct import LoRAUpdateResult
+from sglang.srt.managers.io_struct import LoRAUpdateOutput
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import replace_submodule
@@ -107,8 +107,8 @@ class LoRAManager:
     def create_lora_update_result(
         self, success: bool, error_message: str = ""
-    ) -> LoRAUpdateResult:
-        return LoRAUpdateResult(
+    ) -> LoRAUpdateOutput:
+        return LoRAUpdateOutput(
             success=success,
             error_message=error_message,
             loaded_adapters={
@@ -117,7 +117,7 @@ class LoRAManager:
             },
         )

-    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
+    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
         """
         Load a single LoRA adapter from the specified path.
@@ -174,7 +174,7 @@ class LoRAManager:
                 "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
             )

-    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
+    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
python/sglang/srt/managers/detokenizer_manager.py

@@ -26,11 +26,11 @@ import zmq
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import (
-    BatchEmbeddingOut,
+    BatchEmbeddingOutput,
     BatchMultimodalDecodeReq,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     FreezeGCReq,
     MultiTokenizerRegisterReq,
 )
@@ -101,8 +101,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
         self._request_dispatcher = TypeBasedDispatcher(
             [
-                (BatchEmbeddingOut, self.handle_batch_embedding_out),
-                (BatchTokenIDOut, self.handle_batch_token_id_out),
+                (BatchEmbeddingOutput, self.handle_batch_embedding_out),
+                (BatchTokenIDOutput, self.handle_batch_token_id_out),
                 (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
                 (MultiTokenizerRegisterReq, lambda x: x),
                 (FreezeGCReq, self.handle_freeze_gc_req),
@@ -145,11 +145,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
             return output[:-1]
         return output

-    def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
+    def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOutput):
         # If it is embedding model, no detokenization is needed.
         return recv_obj

-    def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
+    def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
         bs = len(recv_obj.rids)

         # Initialize decode status
@@ -224,7 +224,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
             s.sent_offset = len(output_str)
             output_strs.append(incremental_output)

-        return BatchStrOut(
+        return BatchStrOutput(
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             output_strs=output_strs,
@@ -252,7 +252,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
         outputs = self.tokenizer.detokenize(recv_obj)
-        return BatchMultimodalOut(
+        return BatchMultimodalOutput(
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             outputs=outputs,
python/sglang/srt/managers/io_struct.py

@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler).
 import copy
 import uuid
+from abc import ABC
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@ -36,10 +37,32 @@ else:
 # Parameters for a session
+@dataclass
+class BaseReq(ABC):
+    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
+
+    def regenerate_rid(self):
+        """Generate a new request ID and return it."""
+        if isinstance(self.rid, list):
+            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
+        else:
+            self.rid = uuid.uuid4().hex
+        return self.rid
+
+
+@dataclass
+class BaseBatchReq(ABC):
+    rids: Optional[List[str]] = field(default=None, kw_only=True)
+
+    def regenerate_rids(self):
+        """Generate new request IDs and return them."""
+        self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
+        return self.rids
+
+
 @dataclass
 class SessionParams:
     id: Optional[str] = None
     rid: Optional[str] = None
     offset: Optional[int] = None
     replace: Optional[bool] = None
     drop_previous_output: Optional[bool] = None
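The two base classes above centralize the request-ID handling that each request dataclass previously duplicated. Declaring `rid`/`rids` with `kw_only=True` (a dataclasses feature requiring Python 3.10+) matters: it lets subclasses keep declaring positional fields without defaults, which an inherited defaulted positional field would otherwise forbid. A minimal sketch of the resulting behavior (`MyReq` is hypothetical, for illustration only):

    from abc import ABC
    from dataclasses import dataclass, field
    from typing import List, Optional, Union
    import uuid

    @dataclass
    class BaseReq(ABC):
        # kw_only keeps this defaulted field from clashing with
        # non-default positional fields declared by subclasses.
        rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)

        def regenerate_rid(self):
            if isinstance(self.rid, list):
                self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
            else:
                self.rid = uuid.uuid4().hex
            return self.rid

    @dataclass
    class MyReq(BaseReq):  # hypothetical subclass, for illustration
        payload: str       # positional field with no default: still legal

    req = MyReq("hello", rid="req-1")  # rid must be passed by keyword
    req.regenerate_rid()               # assigns a fresh uuid4 hex string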
@@ -63,7 +86,7 @@ MultimodalDataInputFormat = Union[
 @dataclass
-class GenerateReqInput:
+class GenerateReqInput(BaseReq):
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
     # The token ids for text; one can specify either text or input_ids
@@ -83,8 +106,6 @@ class GenerateReqInput:
     audio_data: Optional[MultimodalDataInputFormat] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
     # If return logprobs, the start location in the prompt for returning logprobs.
@@ -491,11 +512,6 @@ class GenerateReqInput:
         ):
             raise ValueError("Session params must be a dict or a list of dicts.")

-    def regenerate_rid(self):
-        """Generate a new request ID and return it."""
-        self.rid = uuid.uuid4().hex
-        return self.rid
-
     def __getitem__(self, i):
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
@@ -558,9 +574,7 @@ class GenerateReqInput:
 @dataclass
-class TokenizedGenerateReqInput:
-    # The request id
-    rid: str
+class TokenizedGenerateReqInput(BaseReq):
     # The input text
     input_text: str
     # The input token ids
@@ -625,7 +639,7 @@ class TokenizedGenerateReqInput:
 @dataclass
-class BatchTokenizedGenerateReqInput:
+class BatchTokenizedGenerateReqInput(BaseBatchReq):
     # The batch of tokenized requests
     batch: List[TokenizedGenerateReqInput]
@@ -640,7 +654,7 @@ class BatchTokenizedGenerateReqInput:
 @dataclass
-class EmbeddingReqInput:
+class EmbeddingReqInput(BaseReq):
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
@@ -656,8 +670,6 @@ class EmbeddingReqInput:
     audio_data: Optional[MultimodalDataInputFormat] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # Dummy input embeds for compatibility
@@ -728,10 +740,6 @@ class EmbeddingReqInput:
             for i in range(self.batch_size):
                 self.sampling_params[i]["max_new_tokens"] = 0

-    def regenerate_rid(self):
-        self.rid = uuid.uuid4().hex
-        return self.rid
-
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -760,9 +768,7 @@ class EmbeddingReqInput:
 @dataclass
-class TokenizedEmbeddingReqInput:
-    # The request id
-    rid: str
+class TokenizedEmbeddingReqInput(BaseReq):
     # The input text
     input_text: str
     # The input token ids
@@ -780,7 +786,7 @@ class TokenizedEmbeddingReqInput:
 @dataclass
-class BatchTokenizedEmbeddingReqInput:
+class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
     # The batch of tokenized embedding requests
     batch: List[TokenizedEmbeddingReqInput]
@@ -795,9 +801,7 @@ class BatchTokenizedEmbeddingReqInput:
 @dataclass
-class BatchTokenIDOut:
-    # The request id
-    rids: List[str]
+class BatchTokenIDOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # For incremental decoding
@@ -842,7 +846,7 @@ class BatchTokenIDOut:
 @dataclass
-class BatchMultimodalDecodeReq:
+class BatchMultimodalDecodeReq(BaseBatchReq):
     decoded_ids: List[int]
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -854,8 +858,6 @@ class BatchMultimodalDecodeReq:
     image_resolutions: List[List[int]]
     resize_image_resolutions: List[List[int]]

-    # The request id
-    rids: List[str]
     finished_reasons: List[BaseFinishReason]

     # Token counts
@@ -871,9 +873,7 @@ class BatchMultimodalDecodeReq:
 @dataclass
-class BatchStrOut:
-    # The request id
-    rids: List[str]
+class BatchStrOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[dict]
     # The output decoded strings
@@ -909,9 +909,7 @@ class BatchStrOut:
 @dataclass
-class BatchMultimodalOut:
-    # The request id
-    rids: List[str]
+class BatchMultimodalOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[dict]
     decoded_ids: List[List[int]]
@@ -936,9 +934,7 @@ class BatchMultimodalOut:
 @dataclass
-class BatchEmbeddingOut:
-    # The request id
-    rids: List[str]
+class BatchEmbeddingOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # The output embedding
@@ -952,27 +948,27 @@ class BatchEmbeddingOut:
 @dataclass
-class ClearHiCacheReqInput:
+class ClearHiCacheReqInput(BaseReq):
     pass


 @dataclass
-class ClearHiCacheReqOutput:
+class ClearHiCacheReqOutput(BaseReq):
     success: bool


 @dataclass
-class FlushCacheReqInput:
+class FlushCacheReqInput(BaseReq):
     pass


 @dataclass
-class FlushCacheReqOutput:
+class FlushCacheReqOutput(BaseReq):
     success: bool


 @dataclass
-class UpdateWeightFromDiskReqInput:
+class UpdateWeightFromDiskReqInput(BaseReq):
     # The model path with the new weights
     model_path: str
     # The format to load the weights
@@ -990,7 +986,7 @@ class UpdateWeightFromDiskReqInput:
 @dataclass
-class UpdateWeightFromDiskReqOutput:
+class UpdateWeightFromDiskReqOutput(BaseReq):
     success: bool
     message: str
     # Number of paused requests during weight sync.
@@ -998,7 +994,7 @@ class UpdateWeightFromDiskReqOutput:
 @dataclass
-class UpdateWeightsFromDistributedReqInput:
+class UpdateWeightsFromDistributedReqInput(BaseReq):
     names: List[str]
     dtypes: List[str]
     shapes: List[List[int]]
@@ -1013,13 +1009,13 @@ class UpdateWeightsFromDistributedReqInput:
 @dataclass
-class UpdateWeightsFromDistributedReqOutput:
+class UpdateWeightsFromDistributedReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class UpdateWeightsFromTensorReqInput:
+class UpdateWeightsFromTensorReqInput(BaseReq):
     """Update model weights from tensor input.

     - Tensors are serialized for transmission
@@ -1038,13 +1034,13 @@ class UpdateWeightsFromTensorReqInput:
 @dataclass
-class UpdateWeightsFromTensorReqOutput:
+class UpdateWeightsFromTensorReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class InitWeightsSendGroupForRemoteInstanceReqInput:
+class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
     # The master address
     master_address: str
     # The ports for each rank's communication group
@@ -1060,13 +1056,13 @@ class InitWeightsSendGroupForRemoteInstanceReqInput:
 @dataclass
-class InitWeightsSendGroupForRemoteInstanceReqOutput:
+class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class SendWeightsToRemoteInstanceReqInput:
+class SendWeightsToRemoteInstanceReqInput(BaseReq):
     # The master address
     master_address: str
     # The ports for each rank's communication group
@@ -1076,13 +1072,13 @@ class SendWeightsToRemoteInstanceReqInput:
 @dataclass
-class SendWeightsToRemoteInstanceReqOutput:
+class SendWeightsToRemoteInstanceReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class InitWeightsUpdateGroupReqInput:
+class InitWeightsUpdateGroupReqInput(BaseReq):
     # The master address
     master_address: str
     # The master port
@@ -1098,24 +1094,24 @@ class InitWeightsUpdateGroupReqInput:
 @dataclass
-class InitWeightsUpdateGroupReqOutput:
+class InitWeightsUpdateGroupReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class DestroyWeightsUpdateGroupReqInput:
+class DestroyWeightsUpdateGroupReqInput(BaseReq):
     group_name: str = "weight_update_group"


 @dataclass
-class DestroyWeightsUpdateGroupReqOutput:
+class DestroyWeightsUpdateGroupReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class UpdateWeightVersionReqInput:
+class UpdateWeightVersionReqInput(BaseReq):
     # The new weight version
     new_version: str
     # Whether to abort all running requests before updating
@@ -1123,89 +1119,87 @@ class UpdateWeightVersionReqInput:
 @dataclass
-class GetWeightsByNameReqInput:
+class GetWeightsByNameReqInput(BaseReq):
     name: str
     truncate_size: int = 100


 @dataclass
-class GetWeightsByNameReqOutput:
+class GetWeightsByNameReqOutput(BaseReq):
     parameter: list


 @dataclass
-class ReleaseMemoryOccupationReqInput:
+class ReleaseMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None


 @dataclass
-class ReleaseMemoryOccupationReqOutput:
+class ReleaseMemoryOccupationReqOutput(BaseReq):
     pass


 @dataclass
-class ResumeMemoryOccupationReqInput:
+class ResumeMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None


 @dataclass
-class ResumeMemoryOccupationReqOutput:
+class ResumeMemoryOccupationReqOutput(BaseReq):
     pass


 @dataclass
-class SlowDownReqInput:
+class SlowDownReqInput(BaseReq):
     forward_sleep_time: Optional[float]


 @dataclass
-class SlowDownReqOutput:
+class SlowDownReqOutput(BaseReq):
     pass


 @dataclass
-class AbortReq:
-    # The request id
-    rid: str = ""
+class AbortReq(BaseReq):
     # Whether to abort all requests
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
     abort_reason: Optional[str] = None
-    # used in MultiTokenzierManager mode
-    rids: Optional[Union[List[str], str]] = None

     def __post_init__(self):
-        self.rids = self.rid
+        # FIXME: This is a hack to keep the same with the old code
+        if self.rid is None:
+            self.rid = ""


 @dataclass
-class GetInternalStateReq:
+class GetInternalStateReq(BaseReq):
     pass


 @dataclass
-class GetInternalStateReqOutput:
+class GetInternalStateReqOutput(BaseReq):
     internal_state: Dict[Any, Any]


 @dataclass
-class SetInternalStateReq:
+class SetInternalStateReq(BaseReq):
     server_args: Dict[str, Any]


 @dataclass
-class SetInternalStateReqOutput:
+class SetInternalStateReqOutput(BaseReq):
     updated: bool
     server_args: Dict[str, Any]


 @dataclass
-class ProfileReqInput:
+class ProfileReqInput(BaseReq):
     # The output directory
     output_dir: Optional[str] = None
     # If set, it profile as many as this number of steps.
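Note the `AbortReq` rewrite: `rid` now comes from `BaseReq` as a keyword-only field, the `rids` mirror and the `self.rids = self.rid` hack are gone, and `__post_init__` only restores the old `""` default. Call sites therefore switch from positional to keyword `rid`, as the scheduler and tokenizer-manager diffs below show. A small sketch, assuming the `AbortReq` definition above:

    # Old style (rid was the first positional field):
    #     AbortReq(req.rid, abort_all=True)
    # New style (rid is keyword-only, inherited from BaseReq):
    req = AbortReq(abort_all=True, rid="request-123")
    assert req.rid == "request-123"
    assert AbortReq().rid == ""  # __post_init__ maps the None default back to ""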
@@ -1225,7 +1219,7 @@ class ProfileReqType(Enum):
 @dataclass
-class ProfileReq:
+class ProfileReq(BaseReq):
     type: ProfileReqType
     output_dir: Optional[str] = None
     start_step: Optional[int] = None
@@ -1238,18 +1232,18 @@ class ProfileReq:
 @dataclass
-class ProfileReqOutput:
+class ProfileReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class FreezeGCReq:
+class FreezeGCReq(BaseReq):
     pass


 @dataclass
-class ConfigureLoggingReq:
+class ConfigureLoggingReq(BaseReq):
     log_requests: Optional[bool] = None
     log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
@@ -1258,35 +1252,39 @@ class ConfigureLoggingReq:
 @dataclass
-class OpenSessionReqInput:
+class OpenSessionReqInput(BaseReq):
     capacity_of_str_len: int
     session_id: Optional[str] = None


 @dataclass
-class CloseSessionReqInput:
+class CloseSessionReqInput(BaseReq):
     session_id: str


 @dataclass
-class OpenSessionReqOutput:
+class OpenSessionReqOutput(BaseReq):
     session_id: Optional[str]
     success: bool


 @dataclass
-class HealthCheckOutput:
+class HealthCheckOutput(BaseReq):
     pass


-class ExpertDistributionReq(Enum):
+class ExpertDistributionReqType(Enum):
     START_RECORD = 1
     STOP_RECORD = 2
     DUMP_RECORD = 3


+@dataclass
+class ExpertDistributionReq(BaseReq):
+    action: ExpertDistributionReqType
+
+
 @dataclass
-class ExpertDistributionReqOutput:
+class ExpertDistributionReqOutput(BaseReq):
     pass
@@ -1304,7 +1302,7 @@ class Tool:
 @dataclass
-class ParseFunctionCallReq:
+class ParseFunctionCallReq(BaseReq):
     text: str  # The text to parse.
     tools: List[Tool] = field(
         default_factory=list
@@ -1315,31 +1313,31 @@ class ParseFunctionCallReq:
 @dataclass
-class SeparateReasoningReqInput:
+class SeparateReasoningReqInput(BaseReq):
     text: str  # The text to parse.
     reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


 @dataclass
-class VertexGenerateReqInput:
+class VertexGenerateReqInput(BaseReq):
     instances: List[dict]
     parameters: Optional[dict] = None


 @dataclass
-class RpcReqInput:
+class RpcReqInput(BaseReq):
     method: str
     parameters: Optional[Dict] = None


 @dataclass
-class RpcReqOutput:
+class RpcReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class LoadLoRAAdapterReqInput:
+class LoadLoRAAdapterReqInput(BaseReq):
     # The name of the lora module to newly loaded.
     lora_name: str
     # The path of loading.
@@ -1359,7 +1357,7 @@ class LoadLoRAAdapterReqInput:
 @dataclass
-class UnloadLoRAAdapterReqInput:
+class UnloadLoRAAdapterReqInput(BaseReq):
     # The name of lora module to unload.
     lora_name: str
     # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
@@ -1373,23 +1371,23 @@ class UnloadLoRAAdapterReqInput:
 @dataclass
-class LoRAUpdateResult:
+class LoRAUpdateOutput(BaseReq):
     success: bool
     error_message: Optional[str] = None
     loaded_adapters: Optional[Dict[str, LoRARef]] = None


-LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput


 @dataclass
-class MultiTokenizerRegisterReq:
-    rids: Optional[Union[List[str], str]] = None
+class MultiTokenizerRegisterReq(BaseBatchReq):
     ipc_name: Optional[str] = None


 @dataclass
 class MultiTokenizerWrapper:
+    # FIXME(lsyin): remove this
     worker_id: int
     obj: Optional[Any] = None
@@ -1400,17 +1398,17 @@ class BlockReqType(Enum):
 @dataclass
-class BlockReqInput:
+class BlockReqInput(BaseReq):
     type: BlockReqType


 @dataclass
-class GetLoadReqInput:
+class GetLoadReqInput(BaseReq):
     pass


 @dataclass
-class GetLoadReqOutput:
+class GetLoadReqOutput(BaseReq):
     dp_rank: int
     num_reqs: int
     num_waiting_reqs: int
@@ -1418,5 +1416,31 @@ class GetLoadReqOutput:
 @dataclass
-class WatchLoadUpdateReq:
+class WatchLoadUpdateReq(BaseReq):
     loads: List[GetLoadReqOutput]
+
+
+def _check_all_req_types():
+    """A helper function to check all request types are defined in this file."""
+    import inspect
+    import sys
+
+    all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
+    for class_type in all_classes:
+        # check its name
+        name = class_type[0]
+        is_io_struct = (
+            name.endswith("Req") or name.endswith("Input") or name.endswith("Output")
+        )
+        is_base_req = issubclass(class_type[1], BaseReq) or issubclass(
+            class_type[1], BaseBatchReq
+        )
+        if is_io_struct and not is_base_req:
+            raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
+        if is_base_req and not is_io_struct:
+            raise ValueError(
+                f"{name} is a subclass of BaseReq but not follow the naming convention."
+            )
+
+
+_check_all_req_types()
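The module-level `_check_all_req_types()` call makes the convention an import-time invariant: any class whose name ends in `Req`, `Input`, or `Output` must derive from `BaseReq` or `BaseBatchReq`, and any class that derives from them must follow the naming. A standalone sketch of the same pattern (module contents here are hypothetical, for illustration):

    # check_convention.py -- standalone sketch of the import-time naming check
    import inspect
    import sys
    from abc import ABC
    from dataclasses import dataclass

    @dataclass
    class BaseReq(ABC):
        pass

    @dataclass
    class FooReqInput(BaseReq):  # OK: name ends in "Input" and derives from BaseReq
        value: int = 0

    @dataclass
    class FooResult:  # fine as-is; renaming it "FooReqOutput" without subclassing would fail
        value: int = 0

    def check(module_name: str):
        for name, cls in inspect.getmembers(sys.modules[module_name], inspect.isclass):
            is_io_struct = name.endswith(("Req", "Input", "Output"))
            is_base_req = issubclass(cls, BaseReq)
            if is_io_struct and not is_base_req:
                raise ValueError(f"{name} is not a subclass of BaseReq.")
            if is_base_req and not is_io_struct:
                raise ValueError(f"{name} does not follow the naming convention.")

    check(__name__)  # raises at import time if any class violates the convention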
python/sglang/srt/managers/multi_tokenizer_mixin.py

@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
+"""Mixin class and utils for multi-http-worker mode"""
 import asyncio
 import logging
 import multiprocessing as multiprocessing
@@ -30,10 +30,10 @@ import zmq.asyncio
 from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
-    BatchEmbeddingOut,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     MultiTokenizerRegisterReq,
     MultiTokenizerWrapper,
 )
@@ -83,8 +83,8 @@ class SocketMapping:
 def _handle_output_by_index(output, i):
     """NOTE: A maintainable method is better here."""
-    if isinstance(output, BatchTokenIDOut):
-        new_output = BatchTokenIDOut(
+    if isinstance(output, BatchTokenIDOutput):
+        new_output = BatchTokenIDOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -198,8 +198,8 @@ def _handle_output_by_index(output, i):
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
         )
-    elif isinstance(output, BatchEmbeddingOut):
-        new_output = BatchEmbeddingOut(
+    elif isinstance(output, BatchEmbeddingOutput):
+        new_output = BatchEmbeddingOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -216,8 +216,8 @@ def _handle_output_by_index(output, i):
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
         )
-    elif isinstance(output, BatchStrOut):
-        new_output = BatchStrOut(
+    elif isinstance(output, BatchStrOutput):
+        new_output = BatchStrOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -314,8 +314,8 @@ def _handle_output_by_index(output, i):
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
         )
-    elif isinstance(output, BatchMultimodalOut):
-        new_output = BatchMultimodalOut(
+    elif isinstance(output, BatchMultimodalOutput):
+        new_output = BatchMultimodalOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]
@@ -343,7 +343,7 @@ def _handle_output_by_index(output, i):
 class MultiHttpWorkerDetokenizerMixin:
-    """Mixin class for MultiTokenizerManager and DetokenizerManager"""
+    """Mixin class for DetokenizerManager"""

     def get_worker_ids_from_req_rids(self, rids):
         if isinstance(rids, list):
@@ -386,7 +386,7 @@ class MultiHttpWorkerDetokenizerMixin:
 class MultiTokenizerRouter:
-    """A router to receive requests from MultiTokenizerManager"""
+    """A router to receive requests from TokenizerWorker"""

     def __init__(
         self,
@@ -454,8 +454,8 @@ class MultiTokenizerRouter:
         self.socket_mapping.send_output(worker_id, new_recv_obj)


-class MultiTokenizerManager(TokenizerManager):
-    """
-    Multi Process Tokenizer Manager that tokenizes the text.
-    """
+class TokenizerWorker(TokenizerManager):
+    """
+    Tokenizer Worker in multi-http-worker mode
+    """

     def __init__(
         self,
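`_handle_output_by_index` is the multi-http-worker fan-out: it rebuilds a single-element output object (under the new `*Output` names) from position `i` of every per-request list, so each HTTP worker only sees its own request. A simplified sketch of the pattern (field set trimmed to three for illustration; the real classes carry many more parallel lists):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class BatchStrOutput:  # trimmed stand-in for the real class
        rids: List[str]
        finished_reasons: List[Optional[dict]]
        output_strs: List[str]

    def handle_output_by_index(output: BatchStrOutput, i: int) -> BatchStrOutput:
        # Slice every parallel list at index i so the result describes one request.
        return BatchStrOutput(
            rids=[output.rids[i]],
            finished_reasons=[output.finished_reasons[i]],
            output_strs=[output.output_strs[i]],
        )

    batch = BatchStrOutput(
        rids=["r0", "r1"],
        finished_reasons=[None, {"type": "stop"}],
        output_strs=["Hello", "World"],
    )
    assert handle_output_by_index(batch, 1).rids == ["r1"]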
python/sglang/srt/managers/scheduler.py

@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
+    ExpertDistributionReqType,
     FlushCacheReqInput,
     FlushCacheReqOutput,
     FreezeGCReq,
@@ -1487,12 +1488,12 @@ class Scheduler(
                 req.priority = -sys.maxsize - 1
         elif not self.enable_priority_scheduling and req.priority is not None:
             abort_req = AbortReq(
-                req.rid,
                 finished_reason={
                     "type": "abort",
                     "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                     "message": "Using priority is disabled for this server. Please send a new request without a priority.",
                 },
+                rid=req.rid,
             )
             self.send_to_tokenizer.send_pyobj(abort_req)
@@ -1528,12 +1529,12 @@ class Scheduler(
                 self.send_to_tokenizer.send_pyobj(
                     AbortReq(
-                        req_to_abort.rid,
                         finished_reason={
                             "type": "abort",
                             "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                             "message": message,
                         },
+                        rid=req_to_abort.rid,
                     )
                 )
                 return req_to_abort.rid == recv_req.rid
@@ -2005,7 +2006,7 @@ class Scheduler(
         self.new_token_ratio = new_token_ratio
         for req in reqs_to_abort:
             self.send_to_tokenizer.send_pyobj(
-                AbortReq(req.rid, abort_reason=req.to_abort_message)
+                AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
             )
         logger.info(
@@ -2575,7 +2576,7 @@ class Scheduler(
             if self.enable_hicache_storage:
                 # to release prefetch events associated with the request
                 self.tree_cache.release_aborted_request(req.rid)
-            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+            self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
             # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
             if self.disaggregation_mode == DisaggregationMode.DECODE:
                 self.tree_cache.cache_finished_req(req)
@@ -2687,11 +2688,12 @@ class Scheduler(
         return SlowDownReqOutput()

     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
-        if recv_req == ExpertDistributionReq.START_RECORD:
+        action = recv_req.action
+        if action == ExpertDistributionReqType.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
-        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+        elif action == ExpertDistributionReqType.STOP_RECORD:
             get_global_expert_distribution_recorder().stop_record()
-        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+        elif action == ExpertDistributionReqType.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
@@ -2774,7 +2776,8 @@ class IdleSleeper:
 def is_health_check_generate_req(recv_req):
-    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+    rid = getattr(recv_req, "rid", None)
+    return rid is not None and rid.startswith("HEALTH_CHECK")

 def is_work_request(recv_req):
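The `is_health_check_generate_req` change guards against `rid` being `None`, which is now the `BaseReq` default: `getattr(recv_req, "rid", "")` returns the attribute's actual value whenever it exists, so the `""` fallback never applies and `None.startswith(...)` would raise. A small sketch of the failure mode and the fix, using a minimal stand-in object:

    class FakeReq:  # minimal stand-in for a BaseReq subclass
        rid = None  # the attribute exists but is unset

    def old_check(recv_req):
        # getattr's default is ignored because the attribute exists:
        return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")

    def new_check(recv_req):
        rid = getattr(recv_req, "rid", None)
        return rid is not None and rid.startswith("HEALTH_CHECK")

    assert new_check(FakeReq()) is False
    try:
        old_check(FakeReq())
    except AttributeError:
        pass  # None has no .startswith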
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -9,7 +9,11 @@ import torch
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
+)
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch

 if TYPE_CHECKING:
@@ -140,7 +144,7 @@ class SchedulerOutputProcessorMixin:
                         logger.error(
                             f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                         )
-                        self.abort_request(AbortReq(req.rid))
+                        self.abort_request(AbortReq(rid=req.rid))
                         req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
@@ -292,7 +296,7 @@ class SchedulerOutputProcessorMixin:
                         logger.error(
                             f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                         )
-                        self.abort_request(AbortReq(req.rid))
+                        self.abort_request(AbortReq(rid=req.rid))
                         req.grammar.finished = req.finished()

             self.set_next_batch_sampling_info_done(batch)
@@ -714,8 +718,7 @@ class SchedulerOutputProcessorMixin:
             return

         self.send_to_detokenizer.send_pyobj(
-            BatchTokenIDOut(
-                rids,
+            BatchTokenIDOutput(
                 finished_reasons,
                 decoded_texts,
                 decode_ids_list,
@@ -741,6 +744,7 @@ class SchedulerOutputProcessorMixin:
                 output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx,
                 output_hidden_states,
+                rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
             )
@@ -761,12 +765,12 @@ class SchedulerOutputProcessorMixin:
             prompt_tokens.append(len(req.origin_input_ids))
             cached_tokens.append(req.cached_tokens)
         self.send_to_detokenizer.send_pyobj(
-            BatchEmbeddingOut(
-                rids,
+            BatchEmbeddingOutput(
                 finished_reasons,
                 embeddings,
                 prompt_tokens,
                 cached_tokens,
+                rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
             )
python/sglang/srt/managers/tokenizer_communicator_mixin.py

@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqOutput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
+    ExpertDistributionReqType,
     FlushCacheReqInput,
     FlushCacheReqOutput,
     GetInternalStateReq,
@@ -44,7 +45,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqOutput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
+    LoRAUpdateOutput,
     MultiTokenizerWrapper,
     OpenSessionReqInput,
     ProfileReq,
@@ -276,7 +277,7 @@ class TokenizerCommunicatorMixin:
                     self.expert_distribution_communicator.handle_recv,
                 ),
                 (
-                    LoRAUpdateResult,
+                    LoRAUpdateOutput,
                     self.update_lora_adapter_communicator.handle_recv,
                 ),
                 (
@@ -335,15 +336,18 @@ class TokenizerCommunicatorMixin:
     async def start_expert_distribution_record(self: TokenizerManager):
         self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
+        req = ExpertDistributionReq(action=ExpertDistributionReqType.START_RECORD)
+        await self.expert_distribution_communicator(req)

     async def stop_expert_distribution_record(self: TokenizerManager):
         self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
+        req = ExpertDistributionReq(action=ExpertDistributionReqType.STOP_RECORD)
+        await self.expert_distribution_communicator(req)

     async def dump_expert_distribution_record(self: TokenizerManager):
         self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
+        req = ExpertDistributionReq(action=ExpertDistributionReqType.DUMP_RECORD)
+        await self.expert_distribution_communicator(req)

     async def init_weights_update_group(
         self: TokenizerManager,
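With `ExpertDistributionReq` now a `BaseReq` dataclass carrying an `action` field (the old enum survives as `ExpertDistributionReqType`), the record-control messages travel as ordinary request objects and handlers dispatch on the payload instead of comparing the object itself against enum members. A sketch mirroring the scheduler-side handler (simplified to return a string; the real handler calls the expert-distribution recorder):

    from dataclasses import dataclass
    from enum import Enum

    class ExpertDistributionReqType(Enum):
        START_RECORD = 1
        STOP_RECORD = 2
        DUMP_RECORD = 3

    @dataclass
    class ExpertDistributionReq:  # trimmed: the real class also derives from BaseReq
        action: ExpertDistributionReqType

    def handle(recv_req: ExpertDistributionReq) -> str:
        # Dispatch on the payload's action field.
        action = recv_req.action
        if action == ExpertDistributionReqType.START_RECORD:
            return "start"
        elif action == ExpertDistributionReqType.STOP_RECORD:
            return "stop"
        elif action == ExpertDistributionReqType.DUMP_RECORD:
            return "dump"
        raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")

    assert handle(ExpertDistributionReq(action=ExpertDistributionReqType.DUMP_RECORD)) == "dump"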
python/sglang/srt/managers/tokenizer_manager.py

@@ -48,18 +48,17 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
     get_tokenizer_from_processor,
 )
-from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.lora.lora_registry import LoRARegistry
 from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
     FreezeGCReq,
@@ -67,7 +66,6 @@ from sglang.srt.managers.io_struct import (
     GetLoadReqInput,
     HealthCheckOutput,
-    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     SessionParams,
     TokenizedEmbeddingReqInput,
@@ -341,10 +339,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 [
                     (
                         (
-                            BatchStrOut,
-                            BatchEmbeddingOut,
-                            BatchTokenIDOut,
-                            BatchMultimodalOut,
+                            BatchStrOutput,
+                            BatchEmbeddingOutput,
+                            BatchTokenIDOutput,
+                            BatchMultimodalOutput,
                         ),
                         self._handle_batch_output,
                     ),
@@ -716,7 +714,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             )

         tokenized_obj = TokenizedGenerateReqInput(
-            obj.rid,
             input_text,
             input_ids,
             mm_inputs,
@@ -726,6 +723,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             obj.top_logprobs_num,
             obj.token_ids_logprob,
             obj.stream,
+            rid=obj.rid,
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
@@ -740,12 +738,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                obj.rid,
                 input_text,
                 input_ids,
                 mm_inputs,
                 token_type_ids,
                 sampling_params,
+                rid=obj.rid,
                 priority=obj.priority,
             )
@@ -1038,7 +1036,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
-        req = AbortReq(rid, abort_all)
+        req = AbortReq(rid=rid, abort_all=abort_all)
         self.send_to_scheduler.send_pyobj(req)

         if self.enable_metrics:
             # TODO: also use custom_labels from the request
@@ -1303,7 +1301,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
     def _handle_batch_output(
         self,
         recv_obj: Union[
-            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
+            BatchStrOutput,
+            BatchEmbeddingOutput,
+            BatchMultimodalOutput,
+            BatchTokenIDOutput,
         ],
     ):
         for i, rid in enumerate(recv_obj.rids):
@@ -1337,7 +1338,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 i,
             )

-            if not isinstance(recv_obj, BatchEmbeddingOut):
+            if not isinstance(recv_obj, BatchEmbeddingOutput):
                 meta_info.update(
                     {
                         "completion_tokens": recv_obj.completion_tokens[i],
@@ -1348,7 +1349,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if getattr(recv_obj, "output_hidden_states", None):
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

-            if isinstance(recv_obj, BatchStrOut):
+            if isinstance(recv_obj, BatchStrOutput):
                 state.text += recv_obj.output_strs[i]
                 if state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
@@ -1363,7 +1364,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchTokenIDOut):
+            elif isinstance(recv_obj, BatchTokenIDOutput):
                 if self.server_args.stream_output and state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
                     output_token_ids = state.output_ids[state.last_output_offset :]
@@ -1376,10 +1377,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchMultimodalOut):
+            elif isinstance(recv_obj, BatchMultimodalOutput):
                 raise NotImplementedError("BatchMultimodalOut not implemented")
             else:
-                assert isinstance(recv_obj, BatchEmbeddingOut)
+                assert isinstance(recv_obj, BatchEmbeddingOutput)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
@@ -1418,7 +1419,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
-        recv_obj: BatchStrOut,
+        recv_obj: BatchStrOutput,
         recv_obj_index: int,
     ):
         if recv_obj.input_token_logprobs_val is None:
@@ -1536,7 +1537,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             ret.append(None)
         return ret

-    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
         completion_tokens = (
             recv_obj.completion_tokens[i]
             if getattr(recv_obj, "completion_tokens", None)
@@ -1632,7 +1633,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         asyncio.create_task(asyncio.to_thread(background_task))

-    def _handle_abort_req(self, recv_obj):
+    def _handle_abort_req(self, recv_obj: AbortReq):
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]