Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
3c699772
"examples/pytorch/graphsage/dist/train_dist.py" did not exist on "0767c5fcc9c738587a29837050038649e9dde40e"
Unverified
Commit
3c699772
authored
Oct 03, 2025
by
Liangsheng Yin
Committed by
GitHub
Oct 03, 2025
Browse files
Introduce naming convention in `io_struct` and base sglang io classes. (#10133)
parent
e8100774
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
223 additions
and
189 deletions
+223
-189
python/sglang/srt/entrypoints/grpc_request_manager.py
python/sglang/srt/entrypoints/grpc_request_manager.py
+7
-7
python/sglang/srt/entrypoints/http_server.py
python/sglang/srt/entrypoints/http_server.py
+3
-5
python/sglang/srt/lora/lora_manager.py
python/sglang/srt/lora/lora_manager.py
+5
-5
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+10
-10
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+126
-102
python/sglang/srt/managers/multi_tokenizer_mixin.py
python/sglang/srt/managers/multi_tokenizer_mixin.py
+17
-17
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+11
-8
python/sglang/srt/managers/scheduler_output_processor_mixin.py
...n/sglang/srt/managers/scheduler_output_processor_mixin.py
+11
-7
python/sglang/srt/managers/tokenizer_communicator_mixin.py
python/sglang/srt/managers/tokenizer_communicator_mixin.py
+9
-5
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+24
-23
No files found.
python/sglang/srt/entrypoints/grpc_request_manager.py
View file @
3c699772
...
...
@@ -22,8 +22,8 @@ import zmq.asyncio
from
sglang.srt.managers.disagg_service
import
start_disagg_service
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
BatchEmbeddingOut
,
BatchTokenIDOut
,
BatchEmbeddingOut
put
,
BatchTokenIDOut
put
,
HealthCheckOutput
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
...
...
@@ -467,9 +467,9 @@ class GrpcRequestManager:
await
self
.
is_pause_cond
.
wait
()
# Handle different output types
if
isinstance
(
recv_obj
,
BatchTokenIDOut
):
if
isinstance
(
recv_obj
,
BatchTokenIDOut
put
):
await
self
.
_handle_batch_output
(
recv_obj
)
elif
isinstance
(
recv_obj
,
BatchEmbeddingOut
):
elif
isinstance
(
recv_obj
,
BatchEmbeddingOut
put
):
await
self
.
_handle_embedding_output
(
recv_obj
)
elif
isinstance
(
recv_obj
,
HealthCheckOutput
):
await
self
.
_handle_health_check_output
(
recv_obj
)
...
...
@@ -498,7 +498,7 @@ class GrpcRequestManager:
def
_convert_logprob_style
(
self
,
state
:
GrpcReqState
,
batch_out
:
BatchTokenIDOut
,
batch_out
:
BatchTokenIDOut
put
,
batch_index
:
int
,
):
"""
...
...
@@ -545,7 +545,7 @@ class GrpcRequestManager:
batch_out
.
output_top_logprobs_idx
[
batch_index
]
)
async
def
_handle_batch_output
(
self
,
batch_out
:
BatchTokenIDOut
):
async
def
_handle_batch_output
(
self
,
batch_out
:
BatchTokenIDOut
put
):
"""Handle batch generation output from scheduler."""
# Process each request in the batch
for
i
,
rid
in
enumerate
(
batch_out
.
rids
):
...
...
@@ -666,7 +666,7 @@ class GrpcRequestManager:
asyncio
.
create_task
(
cleanup
())
async
def
_handle_embedding_output
(
self
,
batch_out
:
BatchEmbeddingOut
):
async
def
_handle_embedding_output
(
self
,
batch_out
:
BatchEmbeddingOut
put
):
"""Handle batch embedding output from scheduler."""
for
i
,
rid
in
enumerate
(
batch_out
.
rids
):
if
rid
not
in
self
.
rid_to_state
:
...
...
python/sglang/srt/entrypoints/http_server.py
View file @
3c699772
...
...
@@ -94,8 +94,8 @@ from sglang.srt.managers.io_struct import (
VertexGenerateReqInput
,
)
from
sglang.srt.managers.multi_tokenizer_mixin
import
(
MultiTokenizerManager
,
MultiTokenizerRouter
,
TokenizerWorker
,
get_main_process_id
,
monkey_patch_uvicorn_multiprocessing
,
read_from_shared_memory
,
...
...
@@ -127,9 +127,7 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
# Store global states
@
dataclasses
.
dataclass
class
_GlobalState
:
tokenizer_manager
:
Union
[
TokenizerManager
,
MultiTokenizerRouter
,
MultiTokenizerManager
]
tokenizer_manager
:
Union
[
TokenizerManager
,
MultiTokenizerRouter
,
TokenizerWorker
]
template_manager
:
TemplateManager
scheduler_info
:
Dict
...
...
@@ -164,7 +162,7 @@ async def init_multi_tokenizer() -> ServerArgs:
)
# Launch multi-tokenizer manager process
tokenizer_manager
=
Multi
Tokenizer
Manag
er
(
server_args
,
port_args
)
tokenizer_manager
=
Tokenizer
Work
er
(
server_args
,
port_args
)
template_manager
=
TemplateManager
()
template_manager
.
initialize_templates
(
tokenizer_manager
=
tokenizer_manager
,
...
...
python/sglang/srt/lora/lora_manager.py
View file @
3c699772
...
...
@@ -35,7 +35,7 @@ from sglang.srt.lora.utils import (
get_normalized_target_modules
,
get_target_module_name
,
)
from
sglang.srt.managers.io_struct
import
LoRAUpdate
Resul
t
from
sglang.srt.managers.io_struct
import
LoRAUpdate
Outpu
t
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
replace_submodule
...
...
@@ -107,8 +107,8 @@ class LoRAManager:
def
create_lora_update_result
(
self
,
success
:
bool
,
error_message
:
str
=
""
)
->
LoRAUpdate
Resul
t
:
return
LoRAUpdate
Resul
t
(
)
->
LoRAUpdate
Outpu
t
:
return
LoRAUpdate
Outpu
t
(
success
=
success
,
error_message
=
error_message
,
loaded_adapters
=
{
...
...
@@ -117,7 +117,7 @@ class LoRAManager:
},
)
def
load_lora_adapter
(
self
,
lora_ref
:
LoRARef
)
->
LoRAUpdate
Resul
t
:
def
load_lora_adapter
(
self
,
lora_ref
:
LoRARef
)
->
LoRAUpdate
Outpu
t
:
"""
Load a single LoRA adapter from the specified path.
...
...
@@ -174,7 +174,7 @@ class LoRAManager:
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
)
def
unload_lora_adapter
(
self
,
lora_ref
:
LoRARef
)
->
LoRAUpdate
Resul
t
:
def
unload_lora_adapter
(
self
,
lora_ref
:
LoRARef
)
->
LoRAUpdate
Outpu
t
:
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
delete the corresponding LoRA modules.
...
...
python/sglang/srt/managers/detokenizer_manager.py
View file @
3c699772
...
...
@@ -26,11 +26,11 @@ import zmq
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.managers.io_struct
import
(
BatchEmbeddingOut
,
BatchEmbeddingOut
put
,
BatchMultimodalDecodeReq
,
BatchMultimodalOut
,
BatchStrOut
,
BatchTokenIDOut
,
BatchMultimodalOut
put
,
BatchStrOut
put
,
BatchTokenIDOut
put
,
FreezeGCReq
,
MultiTokenizerRegisterReq
,
)
...
...
@@ -101,8 +101,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
self
.
_request_dispatcher
=
TypeBasedDispatcher
(
[
(
BatchEmbeddingOut
,
self
.
handle_batch_embedding_out
),
(
BatchTokenIDOut
,
self
.
handle_batch_token_id_out
),
(
BatchEmbeddingOut
put
,
self
.
handle_batch_embedding_out
),
(
BatchTokenIDOut
put
,
self
.
handle_batch_token_id_out
),
(
BatchMultimodalDecodeReq
,
self
.
handle_multimodal_decode_req
),
(
MultiTokenizerRegisterReq
,
lambda
x
:
x
),
(
FreezeGCReq
,
self
.
handle_freeze_gc_req
),
...
...
@@ -145,11 +145,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
return
output
[:
-
1
]
return
output
def
handle_batch_embedding_out
(
self
,
recv_obj
:
BatchEmbeddingOut
):
def
handle_batch_embedding_out
(
self
,
recv_obj
:
BatchEmbeddingOut
put
):
# If it is embedding model, no detokenization is needed.
return
recv_obj
def
handle_batch_token_id_out
(
self
,
recv_obj
:
BatchTokenIDOut
):
def
handle_batch_token_id_out
(
self
,
recv_obj
:
BatchTokenIDOut
put
):
bs
=
len
(
recv_obj
.
rids
)
# Initialize decode status
...
...
@@ -224,7 +224,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
s
.
sent_offset
=
len
(
output_str
)
output_strs
.
append
(
incremental_output
)
return
BatchStrOut
(
return
BatchStrOut
put
(
rids
=
recv_obj
.
rids
,
finished_reasons
=
recv_obj
.
finished_reasons
,
output_strs
=
output_strs
,
...
...
@@ -252,7 +252,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
def
handle_multimodal_decode_req
(
self
,
recv_obj
:
BatchMultimodalDecodeReq
):
outputs
=
self
.
tokenizer
.
detokenize
(
recv_obj
)
return
BatchMultimodalOut
(
return
BatchMultimodalOut
put
(
rids
=
recv_obj
.
rids
,
finished_reasons
=
recv_obj
.
finished_reasons
,
outputs
=
outputs
,
...
...
python/sglang/srt/managers/io_struct.py
View file @
3c699772
...
...
@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler).
import
copy
import
uuid
from
abc
import
ABC
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
...
...
@@ -36,10 +37,32 @@ else:
# Parameters for a session
@
dataclass
class
BaseReq
(
ABC
):
rid
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
field
(
default
=
None
,
kw_only
=
True
)
def
regenerate_rid
(
self
):
"""Generate a new request ID and return it."""
if
isinstance
(
self
.
rid
,
list
):
self
.
rid
=
[
uuid
.
uuid4
().
hex
for
_
in
range
(
len
(
self
.
rid
))]
else
:
self
.
rid
=
uuid
.
uuid4
().
hex
return
self
.
rid
@
dataclass
class
BaseBatchReq
(
ABC
):
rids
:
Optional
[
List
[
str
]]
=
field
(
default
=
None
,
kw_only
=
True
)
def
regenerate_rids
(
self
):
"""Generate new request IDs and return them."""
self
.
rids
=
[
uuid
.
uuid4
().
hex
for
_
in
range
(
len
(
self
.
rids
))]
return
self
.
rids
@
dataclass
class
SessionParams
:
id
:
Optional
[
str
]
=
None
rid
:
Optional
[
str
]
=
None
offset
:
Optional
[
int
]
=
None
replace
:
Optional
[
bool
]
=
None
drop_previous_output
:
Optional
[
bool
]
=
None
...
...
@@ -63,7 +86,7 @@ MultimodalDataInputFormat = Union[
@
dataclass
class
GenerateReqInput
:
class
GenerateReqInput
(
BaseReq
)
:
# The input prompt. It can be a single prompt or a batch of prompts.
text
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
# The token ids for text; one can specify either text or input_ids
...
...
@@ -83,8 +106,6 @@ class GenerateReqInput:
audio_data
:
Optional
[
MultimodalDataInputFormat
]
=
None
# The sampling_params. See descriptions below.
sampling_params
:
Optional
[
Union
[
List
[
Dict
],
Dict
]]
=
None
# The request id.
rid
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
# Whether to return logprobs.
return_logprob
:
Optional
[
Union
[
List
[
bool
],
bool
]]
=
None
# If return logprobs, the start location in the prompt for returning logprobs.
...
...
@@ -491,11 +512,6 @@ class GenerateReqInput:
):
raise
ValueError
(
"Session params must be a dict or a list of dicts."
)
def
regenerate_rid
(
self
):
"""Generate a new request ID and return it."""
self
.
rid
=
uuid
.
uuid4
().
hex
return
self
.
rid
def
__getitem__
(
self
,
i
):
return
GenerateReqInput
(
text
=
self
.
text
[
i
]
if
self
.
text
is
not
None
else
None
,
...
...
@@ -558,9 +574,7 @@ class GenerateReqInput:
@
dataclass
class
TokenizedGenerateReqInput
:
# The request id
rid
:
str
class
TokenizedGenerateReqInput
(
BaseReq
):
# The input text
input_text
:
str
# The input token ids
...
...
@@ -625,7 +639,7 @@ class TokenizedGenerateReqInput:
@
dataclass
class
BatchTokenizedGenerateReqInput
:
class
BatchTokenizedGenerateReqInput
(
BaseBatchReq
)
:
# The batch of tokenized requests
batch
:
List
[
TokenizedGenerateReqInput
]
...
...
@@ -640,7 +654,7 @@ class BatchTokenizedGenerateReqInput:
@
dataclass
class
EmbeddingReqInput
:
class
EmbeddingReqInput
(
BaseReq
)
:
# The input prompt. It can be a single prompt or a batch of prompts.
text
:
Optional
[
Union
[
List
[
List
[
str
]],
List
[
str
],
str
]]
=
None
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
...
...
@@ -656,8 +670,6 @@ class EmbeddingReqInput:
audio_data
:
Optional
[
MultimodalDataInputFormat
]
=
None
# The token ids for text; one can either specify text or input_ids.
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
# The request id.
rid
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
# Dummy sampling params for compatibility
sampling_params
:
Optional
[
Union
[
List
[
Dict
],
Dict
]]
=
None
# Dummy input embeds for compatibility
...
...
@@ -728,10 +740,6 @@ class EmbeddingReqInput:
for
i
in
range
(
self
.
batch_size
):
self
.
sampling_params
[
i
][
"max_new_tokens"
]
=
0
def
regenerate_rid
(
self
):
self
.
rid
=
uuid
.
uuid4
().
hex
return
self
.
rid
def
contains_mm_input
(
self
)
->
bool
:
return
(
has_valid_data
(
self
.
image_data
)
...
...
@@ -760,9 +768,7 @@ class EmbeddingReqInput:
@
dataclass
class
TokenizedEmbeddingReqInput
:
# The request id
rid
:
str
class
TokenizedEmbeddingReqInput
(
BaseReq
):
# The input text
input_text
:
str
# The input token ids
...
...
@@ -780,7 +786,7 @@ class TokenizedEmbeddingReqInput:
@
dataclass
class
BatchTokenizedEmbeddingReqInput
:
class
BatchTokenizedEmbeddingReqInput
(
BaseBatchReq
)
:
# The batch of tokenized embedding requests
batch
:
List
[
TokenizedEmbeddingReqInput
]
...
...
@@ -795,9 +801,7 @@ class BatchTokenizedEmbeddingReqInput:
@
dataclass
class
BatchTokenIDOut
:
# The request id
rids
:
List
[
str
]
class
BatchTokenIDOutput
(
BaseBatchReq
):
# The finish reason
finished_reasons
:
List
[
BaseFinishReason
]
# For incremental decoding
...
...
@@ -842,7 +846,7 @@ class BatchTokenIDOut:
@
dataclass
class
BatchMultimodalDecodeReq
:
class
BatchMultimodalDecodeReq
(
BaseBatchReq
)
:
decoded_ids
:
List
[
int
]
input_token_logprobs_val
:
List
[
float
]
input_token_logprobs_idx
:
List
[
int
]
...
...
@@ -854,8 +858,6 @@ class BatchMultimodalDecodeReq:
image_resolutions
:
List
[
List
[
int
]]
resize_image_resolutions
:
List
[
List
[
int
]]
# The request id
rids
:
List
[
str
]
finished_reasons
:
List
[
BaseFinishReason
]
# Token counts
...
...
@@ -871,9 +873,7 @@ class BatchMultimodalDecodeReq:
@
dataclass
class
BatchStrOut
:
# The request id
rids
:
List
[
str
]
class
BatchStrOutput
(
BaseBatchReq
):
# The finish reason
finished_reasons
:
List
[
dict
]
# The output decoded strings
...
...
@@ -909,9 +909,7 @@ class BatchStrOut:
@
dataclass
class
BatchMultimodalOut
:
# The request id
rids
:
List
[
str
]
class
BatchMultimodalOutput
(
BaseBatchReq
):
# The finish reason
finished_reasons
:
List
[
dict
]
decoded_ids
:
List
[
List
[
int
]]
...
...
@@ -936,9 +934,7 @@ class BatchMultimodalOut:
@
dataclass
class
BatchEmbeddingOut
:
# The request id
rids
:
List
[
str
]
class
BatchEmbeddingOutput
(
BaseBatchReq
):
# The finish reason
finished_reasons
:
List
[
BaseFinishReason
]
# The output embedding
...
...
@@ -952,27 +948,27 @@ class BatchEmbeddingOut:
@
dataclass
class
ClearHiCacheReqInput
:
class
ClearHiCacheReqInput
(
BaseReq
)
:
pass
@
dataclass
class
ClearHiCacheReqOutput
:
class
ClearHiCacheReqOutput
(
BaseReq
)
:
success
:
bool
@
dataclass
class
FlushCacheReqInput
:
class
FlushCacheReqInput
(
BaseReq
)
:
pass
@
dataclass
class
FlushCacheReqOutput
:
class
FlushCacheReqOutput
(
BaseReq
)
:
success
:
bool
@
dataclass
class
UpdateWeightFromDiskReqInput
:
class
UpdateWeightFromDiskReqInput
(
BaseReq
)
:
# The model path with the new weights
model_path
:
str
# The format to load the weights
...
...
@@ -990,7 +986,7 @@ class UpdateWeightFromDiskReqInput:
@
dataclass
class
UpdateWeightFromDiskReqOutput
:
class
UpdateWeightFromDiskReqOutput
(
BaseReq
)
:
success
:
bool
message
:
str
# Number of paused requests during weight sync.
...
...
@@ -998,7 +994,7 @@ class UpdateWeightFromDiskReqOutput:
@
dataclass
class
UpdateWeightsFromDistributedReqInput
:
class
UpdateWeightsFromDistributedReqInput
(
BaseReq
)
:
names
:
List
[
str
]
dtypes
:
List
[
str
]
shapes
:
List
[
List
[
int
]]
...
...
@@ -1013,13 +1009,13 @@ class UpdateWeightsFromDistributedReqInput:
@
dataclass
class
UpdateWeightsFromDistributedReqOutput
:
class
UpdateWeightsFromDistributedReqOutput
(
BaseReq
)
:
success
:
bool
message
:
str
@
dataclass
class
UpdateWeightsFromTensorReqInput
:
class
UpdateWeightsFromTensorReqInput
(
BaseReq
)
:
"""Update model weights from tensor input.
- Tensors are serialized for transmission
...
...
@@ -1038,13 +1034,13 @@ class UpdateWeightsFromTensorReqInput:
@
dataclass
class
UpdateWeightsFromTensorReqOutput
:
class
UpdateWeightsFromTensorReqOutput
(
BaseReq
)
:
success
:
bool
message
:
str
@
dataclass
class
InitWeightsSendGroupForRemoteInstanceReqInput
:
class
InitWeightsSendGroupForRemoteInstanceReqInput
(
BaseReq
)
:
# The master address
master_address
:
str
# The ports for each rank's communication group
...
...
@@ -1060,13 +1056,13 @@ class InitWeightsSendGroupForRemoteInstanceReqInput:
@
dataclass
class
InitWeightsSendGroupForRemoteInstanceReqOutput
:
class
InitWeightsSendGroupForRemoteInstanceReqOutput
(
BaseReq
)
:
success
:
bool
message
:
str
@
dataclass
class
SendWeightsToRemoteInstanceReqInput
:
class
SendWeightsToRemoteInstanceReqInput
(
BaseReq
)
:
# The master address
master_address
:
str
# The ports for each rank's communication group
...
...
@@ -1076,13 +1072,13 @@ class SendWeightsToRemoteInstanceReqInput:
@
dataclass
class
SendWeightsToRemoteInstanceReqOutput
:
class
SendWeightsToRemoteInstanceReqOutput
(
BaseReq
)
:
success
:
bool
message
:
str
@
dataclass
class
InitWeightsUpdateGroupReqInput
:
class
InitWeightsUpdateGroupReqInput
(
BaseReq
)
:
# The master address
master_address
:
str
# The master port
...
...
@@ -1098,24 +1094,24 @@ class InitWeightsUpdateGroupReqInput:
@
dataclass
class
InitWeightsUpdateGroupReqOutput
:
class
InitWeightsUpdateGroupReqOutput
(
BaseReq
)
:
success
:
bool
message
:
str
@
dataclass
class
DestroyWeightsUpdateGroupReqInput
:
class
DestroyWeightsUpdateGroupReqInput
(
BaseReq
)
:
group_name
:
str
=
"weight_update_group"
@
dataclass
class
DestroyWeightsUpdateGroupReqOutput
:
class
DestroyWeightsUpdateGroupReqOutput
(
BaseReq
)
:
success
:
bool
message
:
str
@
dataclass
class
UpdateWeightVersionReqInput
:
class
UpdateWeightVersionReqInput
(
BaseReq
)
:
# The new weight version
new_version
:
str
# Whether to abort all running requests before updating
...
...
@@ -1123,89 +1119,87 @@ class UpdateWeightVersionReqInput:
@
dataclass
class
GetWeightsByNameReqInput
:
class
GetWeightsByNameReqInput
(
BaseReq
)
:
name
:
str
truncate_size
:
int
=
100
@
dataclass
class
GetWeightsByNameReqOutput
:
class
GetWeightsByNameReqOutput
(
BaseReq
)
:
parameter
:
list
@
dataclass
class
ReleaseMemoryOccupationReqInput
:
class
ReleaseMemoryOccupationReqInput
(
BaseReq
)
:
# Optional tags to identify the memory region, which is primarily used for RL
# Currently we only support `weights` and `kv_cache`
tags
:
Optional
[
List
[
str
]]
=
None
@
dataclass
class
ReleaseMemoryOccupationReqOutput
:
class
ReleaseMemoryOccupationReqOutput
(
BaseReq
)
:
pass
@
dataclass
class
ResumeMemoryOccupationReqInput
:
class
ResumeMemoryOccupationReqInput
(
BaseReq
)
:
# Optional tags to identify the memory region, which is primarily used for RL
# Currently we only support `weights` and `kv_cache`
tags
:
Optional
[
List
[
str
]]
=
None
@
dataclass
class
ResumeMemoryOccupationReqOutput
:
class
ResumeMemoryOccupationReqOutput
(
BaseReq
)
:
pass
@
dataclass
class
SlowDownReqInput
:
class
SlowDownReqInput
(
BaseReq
)
:
forward_sleep_time
:
Optional
[
float
]
@
dataclass
class
SlowDownReqOutput
:
class
SlowDownReqOutput
(
BaseReq
)
:
pass
@
dataclass
class
AbortReq
:
# The request id
rid
:
str
=
""
class
AbortReq
(
BaseReq
):
# Whether to abort all requests
abort_all
:
bool
=
False
# The finished reason data
finished_reason
:
Optional
[
Dict
[
str
,
Any
]]
=
None
abort_reason
:
Optional
[
str
]
=
None
# used in MultiTokenzierManager mode
rids
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
def
__post_init__
(
self
):
self
.
rids
=
self
.
rid
# FIXME: This is a hack to keep the same with the old code
if
self
.
rid
is
None
:
self
.
rid
=
""
@
dataclass
class
GetInternalStateReq
:
class
GetInternalStateReq
(
BaseReq
)
:
pass
@
dataclass
class
GetInternalStateReqOutput
:
class
GetInternalStateReqOutput
(
BaseReq
)
:
internal_state
:
Dict
[
Any
,
Any
]
@
dataclass
class
SetInternalStateReq
:
class
SetInternalStateReq
(
BaseReq
)
:
server_args
:
Dict
[
str
,
Any
]
@
dataclass
class
SetInternalStateReqOutput
:
class
SetInternalStateReqOutput
(
BaseReq
)
:
updated
:
bool
server_args
:
Dict
[
str
,
Any
]
@
dataclass
class
ProfileReqInput
:
class
ProfileReqInput
(
BaseReq
)
:
# The output directory
output_dir
:
Optional
[
str
]
=
None
# If set, it profile as many as this number of steps.
...
...
@@ -1225,7 +1219,7 @@ class ProfileReqType(Enum):
@
dataclass
class
ProfileReq
:
class
ProfileReq
(
BaseReq
)
:
type
:
ProfileReqType
output_dir
:
Optional
[
str
]
=
None
start_step
:
Optional
[
int
]
=
None
...
...
@@ -1238,18 +1232,18 @@ class ProfileReq:
@
dataclass
class
ProfileReqOutput
:
class
ProfileReqOutput
(
BaseReq
)
:
success
:
bool
message
:
str
@
dataclass
class
FreezeGCReq
:
class
FreezeGCReq
(
BaseReq
)
:
pass
@
dataclass
class
ConfigureLoggingReq
:
class
ConfigureLoggingReq
(
BaseReq
)
:
log_requests
:
Optional
[
bool
]
=
None
log_requests_level
:
Optional
[
int
]
=
None
dump_requests_folder
:
Optional
[
str
]
=
None
...
...
@@ -1258,35 +1252,39 @@ class ConfigureLoggingReq:
@
dataclass
class
OpenSessionReqInput
:
class
OpenSessionReqInput
(
BaseReq
)
:
capacity_of_str_len
:
int
session_id
:
Optional
[
str
]
=
None
@
dataclass
class
CloseSessionReqInput
:
class
CloseSessionReqInput
(
BaseReq
)
:
session_id
:
str
@
dataclass
class
OpenSessionReqOutput
:
class
OpenSessionReqOutput
(
BaseReq
)
:
session_id
:
Optional
[
str
]
success
:
bool
@
dataclass
class
HealthCheckOutput
:
class
HealthCheckOutput
(
BaseReq
)
:
pass
class
ExpertDistributionReq
(
Enum
):
class
ExpertDistributionReq
Type
(
Enum
):
START_RECORD
=
1
STOP_RECORD
=
2
DUMP_RECORD
=
3
class
ExpertDistributionReq
(
BaseReq
):
action
:
ExpertDistributionReqType
@
dataclass
class
ExpertDistributionReqOutput
:
class
ExpertDistributionReqOutput
(
BaseReq
)
:
pass
...
...
@@ -1304,7 +1302,7 @@ class Tool:
@
dataclass
class
ParseFunctionCallReq
:
class
ParseFunctionCallReq
(
BaseReq
)
:
text
:
str
# The text to parse.
tools
:
List
[
Tool
]
=
field
(
default_factory
=
list
...
...
@@ -1315,31 +1313,31 @@ class ParseFunctionCallReq:
@
dataclass
class
SeparateReasoningReqInput
:
class
SeparateReasoningReqInput
(
BaseReq
)
:
text
:
str
# The text to parse.
reasoning_parser
:
str
# Specify the parser type, e.g., "deepseek-r1".
@
dataclass
class
VertexGenerateReqInput
:
class
VertexGenerateReqInput
(
BaseReq
)
:
instances
:
List
[
dict
]
parameters
:
Optional
[
dict
]
=
None
@
dataclass
class
RpcReqInput
:
class
RpcReqInput
(
BaseReq
)
:
method
:
str
parameters
:
Optional
[
Dict
]
=
None
@
dataclass
class
RpcReqOutput
:
class
RpcReqOutput
(
BaseReq
)
:
success
:
bool
message
:
str
@
dataclass
class
LoadLoRAAdapterReqInput
:
class
LoadLoRAAdapterReqInput
(
BaseReq
)
:
# The name of the lora module to newly loaded.
lora_name
:
str
# The path of loading.
...
...
@@ -1359,7 +1357,7 @@ class LoadLoRAAdapterReqInput:
@
dataclass
class
UnloadLoRAAdapterReqInput
:
class
UnloadLoRAAdapterReqInput
(
BaseReq
)
:
# The name of lora module to unload.
lora_name
:
str
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
...
...
@@ -1373,23 +1371,23 @@ class UnloadLoRAAdapterReqInput:
@
dataclass
class
LoRAUpdate
Result
:
class
LoRAUpdate
Output
(
BaseReq
)
:
success
:
bool
error_message
:
Optional
[
str
]
=
None
loaded_adapters
:
Optional
[
Dict
[
str
,
LoRARef
]]
=
None
LoadLoRAAdapterReqOutput
=
UnloadLoRAAdapterReqOutput
=
LoRAUpdate
Resul
t
LoadLoRAAdapterReqOutput
=
UnloadLoRAAdapterReqOutput
=
LoRAUpdate
Outpu
t
@
dataclass
class
MultiTokenizerRegisterReq
:
rids
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
class
MultiTokenizerRegisterReq
(
BaseBatchReq
):
ipc_name
:
Optional
[
str
]
=
None
@
dataclass
class
MultiTokenizerWrapper
:
# FIXME(lsyin): remove this
worker_id
:
int
obj
:
Optional
[
Any
]
=
None
...
...
@@ -1400,17 +1398,17 @@ class BlockReqType(Enum):
@
dataclass
class
BlockReqInput
:
class
BlockReqInput
(
BaseReq
)
:
type
:
BlockReqType
@
dataclass
class
GetLoadReqInput
:
class
GetLoadReqInput
(
BaseReq
)
:
pass
@
dataclass
class
GetLoadReqOutput
:
class
GetLoadReqOutput
(
BaseReq
)
:
dp_rank
:
int
num_reqs
:
int
num_waiting_reqs
:
int
...
...
@@ -1418,5 +1416,31 @@ class GetLoadReqOutput:
@
dataclass
class
WatchLoadUpdateReq
:
class
WatchLoadUpdateReq
(
BaseReq
)
:
loads
:
List
[
GetLoadReqOutput
]
def
_check_all_req_types
():
"""A helper function to check all request types are defined in this file."""
import
inspect
import
sys
all_classes
=
inspect
.
getmembers
(
sys
.
modules
[
__name__
],
inspect
.
isclass
)
for
class_type
in
all_classes
:
# check its name
name
=
class_type
[
0
]
is_io_struct
=
(
name
.
endswith
(
"Req"
)
or
name
.
endswith
(
"Input"
)
or
name
.
endswith
(
"Output"
)
)
is_base_req
=
issubclass
(
class_type
[
1
],
BaseReq
)
or
issubclass
(
class_type
[
1
],
BaseBatchReq
)
if
is_io_struct
and
not
is_base_req
:
raise
ValueError
(
f
"
{
name
}
is not a subclass of BaseReq or BaseBatchReq."
)
if
is_base_req
and
not
is_io_struct
:
raise
ValueError
(
f
"
{
name
}
is a subclass of BaseReq but not follow the naming convention."
)
_check_all_req_types
()
python/sglang/srt/managers/multi_tokenizer_mixin.py
View file @
3c699772
...
...
@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""M
ultiTokenizerMixin is a class that provides nesscary method
s for
M
ulti
TokenizerManager and DetokenizerManager.
"""
"""M
ixin class and util
s for
m
ulti
-http-worker mode
"""
import
asyncio
import
logging
import
multiprocessing
as
multiprocessing
...
...
@@ -30,10 +30,10 @@ import zmq.asyncio
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
,
TransferBackend
from
sglang.srt.managers.disagg_service
import
start_disagg_service
from
sglang.srt.managers.io_struct
import
(
BatchEmbeddingOut
,
BatchMultimodalOut
,
BatchStrOut
,
BatchTokenIDOut
,
BatchEmbeddingOut
put
,
BatchMultimodalOut
put
,
BatchStrOut
put
,
BatchTokenIDOut
put
,
MultiTokenizerRegisterReq
,
MultiTokenizerWrapper
,
)
...
...
@@ -83,8 +83,8 @@ class SocketMapping:
def
_handle_output_by_index
(
output
,
i
):
"""NOTE: A maintainable method is better here."""
if
isinstance
(
output
,
BatchTokenIDOut
):
new_output
=
BatchTokenIDOut
(
if
isinstance
(
output
,
BatchTokenIDOut
put
):
new_output
=
BatchTokenIDOut
put
(
rids
=
[
output
.
rids
[
i
]],
finished_reasons
=
(
[
output
.
finished_reasons
[
i
]]
...
...
@@ -198,8 +198,8 @@ def _handle_output_by_index(output, i):
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
)
elif
isinstance
(
output
,
BatchEmbeddingOut
):
new_output
=
BatchEmbeddingOut
(
elif
isinstance
(
output
,
BatchEmbeddingOut
put
):
new_output
=
BatchEmbeddingOut
put
(
rids
=
[
output
.
rids
[
i
]],
finished_reasons
=
(
[
output
.
finished_reasons
[
i
]]
...
...
@@ -216,8 +216,8 @@ def _handle_output_by_index(output, i):
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
)
elif
isinstance
(
output
,
BatchStrOut
):
new_output
=
BatchStrOut
(
elif
isinstance
(
output
,
BatchStrOut
put
):
new_output
=
BatchStrOut
put
(
rids
=
[
output
.
rids
[
i
]],
finished_reasons
=
(
[
output
.
finished_reasons
[
i
]]
...
...
@@ -314,8 +314,8 @@ def _handle_output_by_index(output, i):
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
)
elif
isinstance
(
output
,
BatchMultimodalOut
):
new_output
=
BatchMultimodalOut
(
elif
isinstance
(
output
,
BatchMultimodalOut
put
):
new_output
=
BatchMultimodalOut
put
(
rids
=
[
output
.
rids
[
i
]],
finished_reasons
=
(
[
output
.
finished_reasons
[
i
]]
...
...
@@ -343,7 +343,7 @@ def _handle_output_by_index(output, i):
class
MultiHttpWorkerDetokenizerMixin
:
"""Mixin class for
MultiTokenizerManager and
DetokenizerManager"""
"""Mixin class for DetokenizerManager"""
def
get_worker_ids_from_req_rids
(
self
,
rids
):
if
isinstance
(
rids
,
list
):
...
...
@@ -386,7 +386,7 @@ class MultiHttpWorkerDetokenizerMixin:
class
MultiTokenizerRouter
:
"""A router to receive requests from
Multi
Tokenizer
Manag
er"""
"""A router to receive requests from Tokenizer
Work
er"""
def
__init__
(
self
,
...
...
@@ -454,8 +454,8 @@ class MultiTokenizerRouter:
self
.
socket_mapping
.
send_output
(
worker_id
,
new_recv_obj
)
class
Multi
Tokenizer
Manag
er
(
TokenizerManager
):
"""
Multi Process Tokenizer Manager that tokenizes the text.
"""
class
Tokenizer
Work
er
(
TokenizerManager
):
"""
Tokenizer Worker in multi-http-worker mode
"""
def
__init__
(
self
,
...
...
python/sglang/srt/managers/scheduler.py
View file @
3c699772
...
...
@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
DestroyWeightsUpdateGroupReqInput
,
ExpertDistributionReq
,
ExpertDistributionReqOutput
,
ExpertDistributionReqType
,
FlushCacheReqInput
,
FlushCacheReqOutput
,
FreezeGCReq
,
...
...
@@ -1487,12 +1488,12 @@ class Scheduler(
req
.
priority
=
-
sys
.
maxsize
-
1
elif
not
self
.
enable_priority_scheduling
and
req
.
priority
is
not
None
:
abort_req
=
AbortReq
(
req
.
rid
,
finished_reason
=
{
"type"
:
"abort"
,
"status_code"
:
HTTPStatus
.
SERVICE_UNAVAILABLE
,
"message"
:
"Using priority is disabled for this server. Please send a new request without a priority."
,
},
rid
=
req
.
rid
,
)
self
.
send_to_tokenizer
.
send_pyobj
(
abort_req
)
...
...
@@ -1528,12 +1529,12 @@ class Scheduler(
self
.
send_to_tokenizer
.
send_pyobj
(
AbortReq
(
req_to_abort
.
rid
,
finished_reason
=
{
"type"
:
"abort"
,
"status_code"
:
HTTPStatus
.
SERVICE_UNAVAILABLE
,
"message"
:
message
,
},
rid
=
req_to_abort
.
rid
,
)
)
return
req_to_abort
.
rid
==
recv_req
.
rid
...
...
@@ -2005,7 +2006,7 @@ class Scheduler(
self
.
new_token_ratio
=
new_token_ratio
for
req
in
reqs_to_abort
:
self
.
send_to_tokenizer
.
send_pyobj
(
AbortReq
(
req
.
rid
,
abort_reason
=
req
.
to_abort_message
)
AbortReq
(
abort_reason
=
req
.
to_abort_message
,
rid
=
req
.
rid
)
)
logger
.
info
(
...
...
@@ -2575,7 +2576,7 @@ class Scheduler(
if
self
.
enable_hicache_storage
:
# to release prefetch events associated with the request
self
.
tree_cache
.
release_aborted_request
(
req
.
rid
)
self
.
send_to_tokenizer
.
send_pyobj
(
AbortReq
(
req
.
rid
))
self
.
send_to_tokenizer
.
send_pyobj
(
AbortReq
(
rid
=
req
.
rid
))
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
if
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
self
.
tree_cache
.
cache_finished_req
(
req
)
...
...
@@ -2687,11 +2688,12 @@ class Scheduler(
return
SlowDownReqOutput
()
def
expert_distribution_handle
(
self
,
recv_req
:
ExpertDistributionReq
):
if
recv_req
==
ExpertDistributionReq
.
START_RECORD
:
action
=
recv_req
.
action
if
action
==
ExpertDistributionReqType
.
START_RECORD
:
get_global_expert_distribution_recorder
().
start_record
()
elif
recv_req
==
ExpertDistributionReq
.
STOP_RECORD
:
elif
action
==
ExpertDistributionReq
Type
.
STOP_RECORD
:
get_global_expert_distribution_recorder
().
stop_record
()
elif
recv_req
==
ExpertDistributionReq
.
DUMP_RECORD
:
elif
action
==
ExpertDistributionReq
Type
.
DUMP_RECORD
:
get_global_expert_distribution_recorder
().
dump_record
()
else
:
raise
ValueError
(
f
"Unrecognized ExpertDistributionReq value:
{
recv_req
=
}
"
)
...
...
@@ -2774,7 +2776,8 @@ class IdleSleeper:
def
is_health_check_generate_req
(
recv_req
):
return
getattr
(
recv_req
,
"rid"
,
""
).
startswith
(
"HEALTH_CHECK"
)
rid
=
getattr
(
recv_req
,
"rid"
,
None
)
return
rid
is
not
None
and
rid
.
startswith
(
"HEALTH_CHECK"
)
def
is_work_request
(
recv_req
):
...
...
python/sglang/srt/managers/scheduler_output_processor_mixin.py
View file @
3c699772
...
...
@@ -9,7 +9,11 @@ import torch
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.managers.io_struct
import
AbortReq
,
BatchEmbeddingOut
,
BatchTokenIDOut
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
BatchEmbeddingOutput
,
BatchTokenIDOutput
,
)
from
sglang.srt.managers.schedule_batch
import
BaseFinishReason
,
Req
,
ScheduleBatch
if
TYPE_CHECKING
:
...
...
@@ -140,7 +144,7 @@ class SchedulerOutputProcessorMixin:
logger
.
error
(
f
"Grammar accept_token failed for req
{
req
.
rid
}
with token
{
next_token_id
}
:
{
e
}
"
)
self
.
abort_request
(
AbortReq
(
req
.
rid
))
self
.
abort_request
(
AbortReq
(
rid
=
req
.
rid
))
req
.
grammar
.
finished
=
req
.
finished
()
else
:
# being chunked reqs' prefill is not finished
...
...
@@ -292,7 +296,7 @@ class SchedulerOutputProcessorMixin:
logger
.
error
(
f
"Grammar accept_token failed for req
{
req
.
rid
}
with token
{
next_token_id
}
:
{
e
}
"
)
self
.
abort_request
(
AbortReq
(
req
.
rid
))
self
.
abort_request
(
AbortReq
(
rid
=
req
.
rid
))
req
.
grammar
.
finished
=
req
.
finished
()
self
.
set_next_batch_sampling_info_done
(
batch
)
...
...
@@ -714,8 +718,7 @@ class SchedulerOutputProcessorMixin:
return
self
.
send_to_detokenizer
.
send_pyobj
(
BatchTokenIDOut
(
rids
,
BatchTokenIDOutput
(
finished_reasons
,
decoded_texts
,
decode_ids_list
,
...
...
@@ -741,6 +744,7 @@ class SchedulerOutputProcessorMixin:
output_token_ids_logprobs_val
,
output_token_ids_logprobs_idx
,
output_hidden_states
,
rids
=
rids
,
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
)
...
...
@@ -761,12 +765,12 @@ class SchedulerOutputProcessorMixin:
prompt_tokens
.
append
(
len
(
req
.
origin_input_ids
))
cached_tokens
.
append
(
req
.
cached_tokens
)
self
.
send_to_detokenizer
.
send_pyobj
(
BatchEmbeddingOut
(
rids
,
BatchEmbeddingOutput
(
finished_reasons
,
embeddings
,
prompt_tokens
,
cached_tokens
,
rids
=
rids
,
placeholder_tokens_idx
=
None
,
placeholder_tokens_val
=
None
,
)
...
...
python/sglang/srt/managers/tokenizer_communicator_mixin.py
View file @
3c699772
...
...
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
DestroyWeightsUpdateGroupReqOutput
,
ExpertDistributionReq
,
ExpertDistributionReqOutput
,
ExpertDistributionReqType
,
FlushCacheReqInput
,
FlushCacheReqOutput
,
GetInternalStateReq
,
...
...
@@ -44,7 +45,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqOutput
,
LoadLoRAAdapterReqInput
,
LoadLoRAAdapterReqOutput
,
LoRAUpdate
Resul
t
,
LoRAUpdate
Outpu
t
,
MultiTokenizerWrapper
,
OpenSessionReqInput
,
ProfileReq
,
...
...
@@ -276,7 +277,7 @@ class TokenizerCommunicatorMixin:
self
.
expert_distribution_communicator
.
handle_recv
,
),
(
LoRAUpdate
Resul
t
,
LoRAUpdate
Outpu
t
,
self
.
update_lora_adapter_communicator
.
handle_recv
,
),
(
...
...
@@ -335,15 +336,18 @@ class TokenizerCommunicatorMixin:
async def start_expert_distribution_record(self: TokenizerManager):
    """Ask the scheduler to start recording expert-distribution statistics."""
    self.auto_create_handle_loop()
    req = ExpertDistributionReq(action=ExpertDistributionReqType.START_RECORD)
    await self.expert_distribution_communicator(req)
async def stop_expert_distribution_record(self: TokenizerManager):
    """Ask the scheduler to stop recording expert-distribution statistics."""
    self.auto_create_handle_loop()
    req = ExpertDistributionReq(action=ExpertDistributionReqType.STOP_RECORD)
    await self.expert_distribution_communicator(req)
async def dump_expert_distribution_record(self: TokenizerManager):
    """Ask the scheduler to dump the recorded expert-distribution statistics."""
    self.auto_create_handle_loop()
    req = ExpertDistributionReq(action=ExpertDistributionReqType.DUMP_RECORD)
    await self.expert_distribution_communicator(req)
async
def
init_weights_update_group
(
self
:
TokenizerManager
,
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
3c699772
...
...
@@ -48,18 +48,17 @@ from sglang.srt.hf_transformers_utils import (
get_tokenizer
,
get_tokenizer_from_processor
,
)
from
sglang.srt.lora.lora_registry
import
LoRARef
,
LoRARegistry
from
sglang.srt.lora.lora_registry
import
LoRARegistry
from
sglang.srt.managers.async_dynamic_batch_tokenizer
import
AsyncDynamicbatchTokenizer
from
sglang.srt.managers.disagg_service
import
start_disagg_service
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
BatchEmbeddingOut
,
BatchMultimodalOut
,
BatchStrOut
,
BatchTokenIDOut
,
BatchEmbeddingOut
put
,
BatchMultimodalOut
put
,
BatchStrOut
put
,
BatchTokenIDOut
put
,
BatchTokenizedEmbeddingReqInput
,
BatchTokenizedGenerateReqInput
,
CloseSessionReqInput
,
ConfigureLoggingReq
,
EmbeddingReqInput
,
FreezeGCReq
,
...
...
@@ -67,7 +66,6 @@ from sglang.srt.managers.io_struct import (
GetLoadReqInput
,
HealthCheckOutput
,
MultiTokenizerWrapper
,
OpenSessionReqInput
,
OpenSessionReqOutput
,
SessionParams
,
TokenizedEmbeddingReqInput
,
...
...
@@ -341,10 +339,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
[
(
(
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
,
BatchMultimodalOut
,
BatchStrOut
put
,
BatchEmbeddingOut
put
,
BatchTokenIDOut
put
,
BatchMultimodalOut
put
,
),
self
.
_handle_batch_output
,
),
...
...
@@ -716,7 +714,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
tokenized_obj
=
TokenizedGenerateReqInput
(
obj
.
rid
,
input_text
,
input_ids
,
mm_inputs
,
...
...
@@ -726,6 +723,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
obj
.
top_logprobs_num
,
obj
.
token_ids_logprob
,
obj
.
stream
,
rid
=
obj
.
rid
,
bootstrap_host
=
obj
.
bootstrap_host
,
bootstrap_port
=
obj
.
bootstrap_port
,
bootstrap_room
=
obj
.
bootstrap_room
,
...
...
@@ -740,12 +738,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
elif
isinstance
(
obj
,
EmbeddingReqInput
):
tokenized_obj
=
TokenizedEmbeddingReqInput
(
obj
.
rid
,
input_text
,
input_ids
,
mm_inputs
,
token_type_ids
,
sampling_params
,
rid
=
obj
.
rid
,
priority
=
obj
.
priority
,
)
...
...
@@ -1038,7 +1036,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
def
abort_request
(
self
,
rid
:
str
=
""
,
abort_all
:
bool
=
False
):
if
not
abort_all
and
rid
not
in
self
.
rid_to_state
:
return
req
=
AbortReq
(
rid
,
abort_all
)
req
=
AbortReq
(
rid
=
rid
,
abort_all
=
abort_all
)
self
.
send_to_scheduler
.
send_pyobj
(
req
)
if
self
.
enable_metrics
:
# TODO: also use custom_labels from the request
...
...
@@ -1303,7 +1301,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
def
_handle_batch_output
(
self
,
recv_obj
:
Union
[
BatchStrOut
,
BatchEmbeddingOut
,
BatchMultimodalOut
,
BatchTokenIDOut
BatchStrOutput
,
BatchEmbeddingOutput
,
BatchMultimodalOutput
,
BatchTokenIDOutput
,
],
):
for
i
,
rid
in
enumerate
(
recv_obj
.
rids
):
...
...
@@ -1337,7 +1338,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
i
,
)
if
not
isinstance
(
recv_obj
,
BatchEmbeddingOut
):
if
not
isinstance
(
recv_obj
,
BatchEmbeddingOut
put
):
meta_info
.
update
(
{
"completion_tokens"
:
recv_obj
.
completion_tokens
[
i
],
...
...
@@ -1348,7 +1349,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if
getattr
(
recv_obj
,
"output_hidden_states"
,
None
):
meta_info
[
"hidden_states"
]
=
recv_obj
.
output_hidden_states
[
i
]
if
isinstance
(
recv_obj
,
BatchStrOut
):
if
isinstance
(
recv_obj
,
BatchStrOut
put
):
state
.
text
+=
recv_obj
.
output_strs
[
i
]
if
state
.
obj
.
stream
:
state
.
output_ids
.
extend
(
recv_obj
.
output_ids
[
i
])
...
...
@@ -1363,7 +1364,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"output_ids"
:
output_token_ids
,
"meta_info"
:
meta_info
,
}
elif
isinstance
(
recv_obj
,
BatchTokenIDOut
):
elif
isinstance
(
recv_obj
,
BatchTokenIDOut
put
):
if
self
.
server_args
.
stream_output
and
state
.
obj
.
stream
:
state
.
output_ids
.
extend
(
recv_obj
.
output_ids
[
i
])
output_token_ids
=
state
.
output_ids
[
state
.
last_output_offset
:]
...
...
@@ -1376,10 +1377,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"output_ids"
:
output_token_ids
,
"meta_info"
:
meta_info
,
}
elif
isinstance
(
recv_obj
,
BatchMultimodalOut
):
elif
isinstance
(
recv_obj
,
BatchMultimodalOut
put
):
raise
NotImplementedError
(
"BatchMultimodalOut not implemented"
)
else
:
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
put
)
out_dict
=
{
"embedding"
:
recv_obj
.
embeddings
[
i
],
"meta_info"
:
meta_info
,
...
...
@@ -1418,7 +1419,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
top_logprobs_num
:
int
,
token_ids_logprob
:
List
[
int
],
return_text_in_logprobs
:
bool
,
recv_obj
:
BatchStrOut
,
recv_obj
:
BatchStrOut
put
,
recv_obj_index
:
int
,
):
if
recv_obj
.
input_token_logprobs_val
is
None
:
...
...
@@ -1536,7 +1537,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
ret
.
append
(
None
)
return
ret
def
collect_metrics
(
self
,
state
:
ReqState
,
recv_obj
:
BatchStrOut
,
i
:
int
):
def
collect_metrics
(
self
,
state
:
ReqState
,
recv_obj
:
BatchStrOut
put
,
i
:
int
):
completion_tokens
=
(
recv_obj
.
completion_tokens
[
i
]
if
getattr
(
recv_obj
,
"completion_tokens"
,
None
)
...
...
@@ -1632,7 +1633,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
asyncio
.
create_task
(
asyncio
.
to_thread
(
background_task
))
def
_handle_abort_req
(
self
,
recv_obj
):
def
_handle_abort_req
(
self
,
recv_obj
:
AbortReq
):
if
is_health_check_generate_req
(
recv_obj
):
return
state
=
self
.
rid_to_state
[
recv_obj
.
rid
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment