sglang · Commits · Commit 3c699772 (Unverified)

Authored Oct 03, 2025 by Liangsheng Yin; committed via GitHub on Oct 03, 2025.
Introduce naming convention in `io_struct` and base sglang io classes. (#10133)
Parent: e8100774

Showing 10 changed files with 223 additions and 189 deletions (+223 / -189).
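In short, the commit establishes a convention: every cross-process IO dataclass in io_struct.py is named *Req, *Input, or *Output and derives from one of two new base classes, BaseReq (single request, owns `rid`) or BaseBatchReq (batched, owns `rids`). The id fields are declared keyword-only so that subclasses keep their positional field order. A minimal sketch of the pattern, simplified from the diff below (kw_only requires Python 3.10+):

    import uuid
    from abc import ABC
    from dataclasses import dataclass, field
    from typing import List, Optional, Union

    @dataclass
    class BaseReq(ABC):
        # kw_only keeps rid out of the positional __init__ signature, so
        # subclasses may still declare positional fields without defaults.
        rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)

    @dataclass
    class TokenizedGenerateReqInput(BaseReq):
        input_text: str  # positional, even though the base declares a default

    req = TokenizedGenerateReqInput("hello", rid="req-1")
    assert req.rid == "req-1"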
Changed files:

python/sglang/srt/entrypoints/grpc_request_manager.py              +7    -7
python/sglang/srt/entrypoints/http_server.py                       +3    -5
python/sglang/srt/lora/lora_manager.py                             +5    -5
python/sglang/srt/managers/detokenizer_manager.py                  +10   -10
python/sglang/srt/managers/io_struct.py                            +126  -102
python/sglang/srt/managers/multi_tokenizer_mixin.py                +17   -17
python/sglang/srt/managers/scheduler.py                            +11   -8
python/sglang/srt/managers/scheduler_output_processor_mixin.py     +11   -7
python/sglang/srt/managers/tokenizer_communicator_mixin.py         +9    -5
python/sglang/srt/managers/tokenizer_manager.py                    +24   -23
python/sglang/srt/entrypoints/grpc_request_manager.py

@@ -22,8 +22,8 @@ import zmq.asyncio
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
     HealthCheckOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,

@@ -467,9 +467,9 @@ class GrpcRequestManager:
                     await self.is_pause_cond.wait()

                 # Handle different output types
-                if isinstance(recv_obj, BatchTokenIDOut):
+                if isinstance(recv_obj, BatchTokenIDOutput):
                     await self._handle_batch_output(recv_obj)
-                elif isinstance(recv_obj, BatchEmbeddingOut):
+                elif isinstance(recv_obj, BatchEmbeddingOutput):
                     await self._handle_embedding_output(recv_obj)
                 elif isinstance(recv_obj, HealthCheckOutput):
                     await self._handle_health_check_output(recv_obj)

@@ -498,7 +498,7 @@ class GrpcRequestManager:
     def _convert_logprob_style(
         self,
         state: GrpcReqState,
-        batch_out: BatchTokenIDOut,
+        batch_out: BatchTokenIDOutput,
         batch_index: int,
     ):
         """

@@ -545,7 +545,7 @@ class GrpcRequestManager:
             batch_out.output_top_logprobs_idx[batch_index]
         )

-    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
+    async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
         """Handle batch generation output from scheduler."""
         # Process each request in the batch
         for i, rid in enumerate(batch_out.rids):

@@ -666,7 +666,7 @@ class GrpcRequestManager:
         asyncio.create_task(cleanup())

-    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
+    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
         """Handle batch embedding output from scheduler."""
         for i, rid in enumerate(batch_out.rids):
             if rid not in self.rid_to_state:
python/sglang/srt/entrypoints/http_server.py

@@ -94,8 +94,8 @@ from sglang.srt.managers.io_struct import (
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
-    MultiTokenizerManager,
     MultiTokenizerRouter,
+    TokenizerWorker,
     get_main_process_id,
     monkey_patch_uvicorn_multiprocessing,
     read_from_shared_memory,

@@ -127,9 +127,7 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
-    tokenizer_manager: Union[
-        TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
-    ]
+    tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker]
     template_manager: TemplateManager
     scheduler_info: Dict

@@ -164,7 +162,7 @@ async def init_multi_tokenizer() -> ServerArgs:
     )

     # Launch multi-tokenizer manager process
-    tokenizer_manager = MultiTokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerWorker(server_args, port_args)
     template_manager = TemplateManager()
     template_manager.initialize_templates(
         tokenizer_manager=tokenizer_manager,
python/sglang/srt/lora/lora_manager.py

@@ -35,7 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_target_modules,
     get_target_module_name,
 )
-from sglang.srt.managers.io_struct import LoRAUpdateResult
+from sglang.srt.managers.io_struct import LoRAUpdateOutput
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import replace_submodule

@@ -107,8 +107,8 @@ class LoRAManager:
     def create_lora_update_result(
         self, success: bool, error_message: str = ""
-    ) -> LoRAUpdateResult:
-        return LoRAUpdateResult(
+    ) -> LoRAUpdateOutput:
+        return LoRAUpdateOutput(
             success=success,
             error_message=error_message,
             loaded_adapters={

@@ -117,7 +117,7 @@ class LoRAManager:
             },
         )

-    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
+    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
         """
         Load a single LoRA adapter from the specified path.

@@ -174,7 +174,7 @@ class LoRAManager:
             "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
         )

-    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
+    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
python/sglang/srt/managers/detokenizer_manager.py

@@ -26,11 +26,11 @@ import zmq
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import (
-    BatchEmbeddingOut,
+    BatchEmbeddingOutput,
     BatchMultimodalDecodeReq,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     FreezeGCReq,
     MultiTokenizerRegisterReq,
 )

@@ -101,8 +101,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
         self._request_dispatcher = TypeBasedDispatcher(
             [
-                (BatchEmbeddingOut, self.handle_batch_embedding_out),
-                (BatchTokenIDOut, self.handle_batch_token_id_out),
+                (BatchEmbeddingOutput, self.handle_batch_embedding_out),
+                (BatchTokenIDOutput, self.handle_batch_token_id_out),
                 (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
                 (MultiTokenizerRegisterReq, lambda x: x),
                 (FreezeGCReq, self.handle_freeze_gc_req),

@@ -145,11 +145,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
             return output[:-1]
         return output

-    def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
+    def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOutput):
         # If it is embedding model, no detokenization is needed.
         return recv_obj

-    def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
+    def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
         bs = len(recv_obj.rids)

         # Initialize decode status

@@ -224,7 +224,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                 s.sent_offset = len(output_str)
             output_strs.append(incremental_output)

-        return BatchStrOut(
+        return BatchStrOutput(
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             output_strs=output_strs,

@@ -252,7 +252,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
         outputs = self.tokenizer.detokenize(recv_obj)
-        return BatchMultimodalOut(
+        return BatchMultimodalOutput(
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             outputs=outputs,
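For context, the dispatcher being re-registered above routes each received IPC object to a handler by its concrete type, which is why every rename must touch these registration tables. A toy equivalent (sglang has its own TypeBasedDispatcher in its utils; this is only an illustrative sketch, not the project's code):

    from typing import Any, Callable, List, Tuple, Type

    class TypeBasedDispatcher:
        def __init__(self, mapping: List[Tuple[Type, Callable]]):
            self._mapping = mapping

        def __call__(self, obj: Any):
            # First registered type that matches wins; renaming a class is
            # safe only if every registration site is updated with it.
            for ty, fn in self._mapping:
                if isinstance(obj, ty):
                    return fn(obj)
            raise ValueError(f"Unhandled object type: {type(obj)}")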
python/sglang/srt/managers/io_struct.py

@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler).
 import copy
 import uuid
+from abc import ABC
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

@@ -36,10 +37,32 @@ else:
 # Parameters for a session
+@dataclass
+class BaseReq(ABC):
+    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
+
+    def regenerate_rid(self):
+        """Generate a new request ID and return it."""
+        if isinstance(self.rid, list):
+            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
+        else:
+            self.rid = uuid.uuid4().hex
+        return self.rid
+
+
+@dataclass
+class BaseBatchReq(ABC):
+    rids: Optional[List[str]] = field(default=None, kw_only=True)
+
+    def regenerate_rids(self):
+        """Generate new request IDs and return them."""
+        self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
+        return self.rids
+
+
 @dataclass
 class SessionParams:
     id: Optional[str] = None
     rid: Optional[str] = None
     offset: Optional[int] = None
     replace: Optional[bool] = None
     drop_previous_output: Optional[bool] = None
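These base-class helpers replace the per-class regenerate_rid methods deleted further down. A self-contained illustration of the new behavior (BaseReq copied from the hunk above, minus the ABC base):

    import uuid
    from dataclasses import dataclass, field
    from typing import List, Optional, Union

    @dataclass
    class BaseReq:
        rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)

        def regenerate_rid(self):
            """Generate a new request ID and return it."""
            if isinstance(self.rid, list):
                # Batched rid: one fresh uuid4 hex per entry.
                self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
            else:
                # Single (or unset) rid: one fresh uuid4 hex string.
                self.rid = uuid.uuid4().hex
            return self.rid

    req = BaseReq(rid=["r1", "r2"])
    assert len(req.regenerate_rid()) == 2

Unlike the old per-class versions, the base-class helper also handles a list-valued rid, which batched requests carry.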
@@ -63,7 +86,7 @@ MultimodalDataInputFormat = Union[

 @dataclass
-class GenerateReqInput:
+class GenerateReqInput(BaseReq):
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
     # The token ids for text; one can specify either text or input_ids

@@ -83,8 +106,6 @@ class GenerateReqInput:
     audio_data: Optional[MultimodalDataInputFormat] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
     # If return logprobs, the start location in the prompt for returning logprobs.

@@ -491,11 +512,6 @@ class GenerateReqInput:
         ):
             raise ValueError("Session params must be a dict or a list of dicts.")

-    def regenerate_rid(self):
-        """Generate a new request ID and return it."""
-        self.rid = uuid.uuid4().hex
-        return self.rid
-
     def __getitem__(self, i):
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,

@@ -558,9 +574,7 @@ class GenerateReqInput:

 @dataclass
-class TokenizedGenerateReqInput:
-    # The request id
-    rid: str
+class TokenizedGenerateReqInput(BaseReq):
     # The input text
     input_text: str
     # The input token ids

@@ -625,7 +639,7 @@ class TokenizedGenerateReqInput:

 @dataclass
-class BatchTokenizedGenerateReqInput:
+class BatchTokenizedGenerateReqInput(BaseBatchReq):
     # The batch of tokenized requests
     batch: List[TokenizedGenerateReqInput]

@@ -640,7 +654,7 @@ class BatchTokenizedGenerateReqInput:

 @dataclass
-class EmbeddingReqInput:
+class EmbeddingReqInput(BaseReq):
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.

@@ -656,8 +670,6 @@ class EmbeddingReqInput:
     audio_data: Optional[MultimodalDataInputFormat] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # Dummy input embeds for compatibility

@@ -728,10 +740,6 @@ class EmbeddingReqInput:
             for i in range(self.batch_size):
                 self.sampling_params[i]["max_new_tokens"] = 0

-    def regenerate_rid(self):
-        self.rid = uuid.uuid4().hex
-        return self.rid
-
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)

@@ -760,9 +768,7 @@ class EmbeddingReqInput:

 @dataclass
-class TokenizedEmbeddingReqInput:
-    # The request id
-    rid: str
+class TokenizedEmbeddingReqInput(BaseReq):
     # The input text
     input_text: str
     # The input token ids

@@ -780,7 +786,7 @@ class TokenizedEmbeddingReqInput:

 @dataclass
-class BatchTokenizedEmbeddingReqInput:
+class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
     # The batch of tokenized embedding requests
     batch: List[TokenizedEmbeddingReqInput]

@@ -795,9 +801,7 @@ class BatchTokenizedEmbeddingReqInput:

 @dataclass
-class BatchTokenIDOut:
-    # The request id
-    rids: List[str]
+class BatchTokenIDOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # For incremental decoding

@@ -842,7 +846,7 @@ class BatchTokenIDOut:

 @dataclass
-class BatchMultimodalDecodeReq:
+class BatchMultimodalDecodeReq(BaseBatchReq):
     decoded_ids: List[int]
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]

@@ -854,8 +858,6 @@ class BatchMultimodalDecodeReq:
     image_resolutions: List[List[int]]
     resize_image_resolutions: List[List[int]]

-    # The request id
-    rids: List[str]
     finished_reasons: List[BaseFinishReason]
     # Token counts

@@ -871,9 +873,7 @@ class BatchMultimodalDecodeReq:

 @dataclass
-class BatchStrOut:
-    # The request id
-    rids: List[str]
+class BatchStrOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[dict]
     # The output decoded strings

@@ -909,9 +909,7 @@ class BatchStrOut:

 @dataclass
-class BatchMultimodalOut:
-    # The request id
-    rids: List[str]
+class BatchMultimodalOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[dict]
     decoded_ids: List[List[int]]

@@ -936,9 +934,7 @@ class BatchMultimodalOut:

 @dataclass
-class BatchEmbeddingOut:
-    # The request id
-    rids: List[str]
+class BatchEmbeddingOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # The output embedding
@@ -952,27 +948,27 @@ class BatchEmbeddingOut:

 @dataclass
-class ClearHiCacheReqInput:
+class ClearHiCacheReqInput(BaseReq):
     pass


 @dataclass
-class ClearHiCacheReqOutput:
+class ClearHiCacheReqOutput(BaseReq):
     success: bool


 @dataclass
-class FlushCacheReqInput:
+class FlushCacheReqInput(BaseReq):
     pass


 @dataclass
-class FlushCacheReqOutput:
+class FlushCacheReqOutput(BaseReq):
     success: bool


 @dataclass
-class UpdateWeightFromDiskReqInput:
+class UpdateWeightFromDiskReqInput(BaseReq):
     # The model path with the new weights
     model_path: str
     # The format to load the weights

@@ -990,7 +986,7 @@ class UpdateWeightFromDiskReqInput:

 @dataclass
-class UpdateWeightFromDiskReqOutput:
+class UpdateWeightFromDiskReqOutput(BaseReq):
     success: bool
     message: str
     # Number of paused requests during weight sync.

@@ -998,7 +994,7 @@ class UpdateWeightFromDiskReqOutput:

 @dataclass
-class UpdateWeightsFromDistributedReqInput:
+class UpdateWeightsFromDistributedReqInput(BaseReq):
     names: List[str]
     dtypes: List[str]
     shapes: List[List[int]]

@@ -1013,13 +1009,13 @@ class UpdateWeightsFromDistributedReqInput:

 @dataclass
-class UpdateWeightsFromDistributedReqOutput:
+class UpdateWeightsFromDistributedReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class UpdateWeightsFromTensorReqInput:
+class UpdateWeightsFromTensorReqInput(BaseReq):
     """Update model weights from tensor input.

     - Tensors are serialized for transmission

@@ -1038,13 +1034,13 @@ class UpdateWeightsFromTensorReqInput:

 @dataclass
-class UpdateWeightsFromTensorReqOutput:
+class UpdateWeightsFromTensorReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class InitWeightsSendGroupForRemoteInstanceReqInput:
+class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
     # The master address
     master_address: str
     # The ports for each rank's communication group

@@ -1060,13 +1056,13 @@ class InitWeightsSendGroupForRemoteInstanceReqInput:

 @dataclass
-class InitWeightsSendGroupForRemoteInstanceReqOutput:
+class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class SendWeightsToRemoteInstanceReqInput:
+class SendWeightsToRemoteInstanceReqInput(BaseReq):
     # The master address
     master_address: str
     # The ports for each rank's communication group

@@ -1076,13 +1072,13 @@ class SendWeightsToRemoteInstanceReqInput:

 @dataclass
-class SendWeightsToRemoteInstanceReqOutput:
+class SendWeightsToRemoteInstanceReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class InitWeightsUpdateGroupReqInput:
+class InitWeightsUpdateGroupReqInput(BaseReq):
     # The master address
     master_address: str
     # The master port

@@ -1098,24 +1094,24 @@ class InitWeightsUpdateGroupReqInput:

 @dataclass
-class InitWeightsUpdateGroupReqOutput:
+class InitWeightsUpdateGroupReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class DestroyWeightsUpdateGroupReqInput:
+class DestroyWeightsUpdateGroupReqInput(BaseReq):
     group_name: str = "weight_update_group"


 @dataclass
-class DestroyWeightsUpdateGroupReqOutput:
+class DestroyWeightsUpdateGroupReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class UpdateWeightVersionReqInput:
+class UpdateWeightVersionReqInput(BaseReq):
     # The new weight version
     new_version: str
     # Whether to abort all running requests before updating

@@ -1123,89 +1119,87 @@ class UpdateWeightVersionReqInput:

 @dataclass
-class GetWeightsByNameReqInput:
+class GetWeightsByNameReqInput(BaseReq):
     name: str
     truncate_size: int = 100


 @dataclass
-class GetWeightsByNameReqOutput:
+class GetWeightsByNameReqOutput(BaseReq):
     parameter: list


 @dataclass
-class ReleaseMemoryOccupationReqInput:
+class ReleaseMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None


 @dataclass
-class ReleaseMemoryOccupationReqOutput:
+class ReleaseMemoryOccupationReqOutput(BaseReq):
     pass


 @dataclass
-class ResumeMemoryOccupationReqInput:
+class ResumeMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None


 @dataclass
-class ResumeMemoryOccupationReqOutput:
+class ResumeMemoryOccupationReqOutput(BaseReq):
     pass


 @dataclass
-class SlowDownReqInput:
+class SlowDownReqInput(BaseReq):
     forward_sleep_time: Optional[float]


 @dataclass
-class SlowDownReqOutput:
+class SlowDownReqOutput(BaseReq):
     pass


 @dataclass
-class AbortReq:
-    # The request id
-    rid: str = ""
+class AbortReq(BaseReq):
     # Whether to abort all requests
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
     abort_reason: Optional[str] = None
-    # used in MultiTokenzierManager mode
-    rids: Optional[Union[List[str], str]] = None

     def __post_init__(self):
-        self.rids = self.rid
+        # FIXME: This is a hack to keep the same with the old code
+        if self.rid is None:
+            self.rid = ""


 @dataclass
-class GetInternalStateReq:
+class GetInternalStateReq(BaseReq):
     pass


 @dataclass
-class GetInternalStateReqOutput:
+class GetInternalStateReqOutput(BaseReq):
     internal_state: Dict[Any, Any]


 @dataclass
-class SetInternalStateReq:
+class SetInternalStateReq(BaseReq):
     server_args: Dict[str, Any]


 @dataclass
-class SetInternalStateReqOutput:
+class SetInternalStateReqOutput(BaseReq):
     updated: bool
     server_args: Dict[str, Any]


 @dataclass
-class ProfileReqInput:
+class ProfileReqInput(BaseReq):
     # The output directory
     output_dir: Optional[str] = None
     # If set, it profile as many as this number of steps.
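Note what the AbortReq change above means at call sites: rid is no longer a positional field of AbortReq itself but the keyword-only field inherited from BaseReq, so AbortReq(req.rid) stops doing the right thing. A simplified repro (our sketch, not the full class):

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class BaseReq:
        rid: Optional[str] = field(default=None, kw_only=True)

    @dataclass
    class AbortReq(BaseReq):
        abort_all: bool = False

        def __post_init__(self):
            # Mirrors the FIXME hack above: old code expects rid == "".
            if self.rid is None:
                self.rid = ""

    assert AbortReq(rid="req-1").rid == "req-1"   # keyword form: OK
    assert AbortReq().rid == ""                   # defaulting hack preserved
    # AbortReq("req-1") would now silently bind "req-1" to abort_all,
    # which is why the call sites in scheduler.py below switch to rid=...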
@@ -1225,7 +1219,7 @@ class ProfileReqType(Enum):

 @dataclass
-class ProfileReq:
+class ProfileReq(BaseReq):
     type: ProfileReqType
     output_dir: Optional[str] = None
     start_step: Optional[int] = None

@@ -1238,18 +1232,18 @@ class ProfileReq:

 @dataclass
-class ProfileReqOutput:
+class ProfileReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class FreezeGCReq:
+class FreezeGCReq(BaseReq):
     pass


 @dataclass
-class ConfigureLoggingReq:
+class ConfigureLoggingReq(BaseReq):
     log_requests: Optional[bool] = None
     log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None

@@ -1258,35 +1252,39 @@ class ConfigureLoggingReq:

 @dataclass
-class OpenSessionReqInput:
+class OpenSessionReqInput(BaseReq):
     capacity_of_str_len: int
     session_id: Optional[str] = None


 @dataclass
-class CloseSessionReqInput:
+class CloseSessionReqInput(BaseReq):
     session_id: str


 @dataclass
-class OpenSessionReqOutput:
+class OpenSessionReqOutput(BaseReq):
     session_id: Optional[str]
     success: bool


 @dataclass
-class HealthCheckOutput:
+class HealthCheckOutput(BaseReq):
     pass


-class ExpertDistributionReq(Enum):
+class ExpertDistributionReqType(Enum):
     START_RECORD = 1
     STOP_RECORD = 2
     DUMP_RECORD = 3


+@dataclass
+class ExpertDistributionReq(BaseReq):
+    action: ExpertDistributionReqType
+
+
 @dataclass
-class ExpertDistributionReqOutput:
+class ExpertDistributionReqOutput(BaseReq):
     pass
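The expert-distribution request changes shape here: the request class used to be the Enum itself; now the Enum is ExpertDistributionReqType and the request is an ordinary dataclass carrying it, so it can join the BaseReq hierarchy. A minimal sketch of the new shape (simplified, without the BaseReq base):

    from dataclasses import dataclass
    from enum import Enum

    class ExpertDistributionReqType(Enum):
        START_RECORD = 1
        STOP_RECORD = 2
        DUMP_RECORD = 3

    @dataclass
    class ExpertDistributionReq:  # subclasses BaseReq in the real file
        action: ExpertDistributionReqType

    # Before: handlers compared the request object itself against enum members.
    # After: they compare req.action (see expert_distribution_handle in
    # scheduler.py below).
    req = ExpertDistributionReq(action=ExpertDistributionReqType.START_RECORD)
    assert req.action is ExpertDistributionReqType.START_RECORD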
@@ -1304,7 +1302,7 @@ class Tool:

 @dataclass
-class ParseFunctionCallReq:
+class ParseFunctionCallReq(BaseReq):
     text: str  # The text to parse.
     tools: List[Tool] = field(
         default_factory=list

@@ -1315,31 +1313,31 @@ class ParseFunctionCallReq:

 @dataclass
-class SeparateReasoningReqInput:
+class SeparateReasoningReqInput(BaseReq):
     text: str  # The text to parse.
     reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".


 @dataclass
-class VertexGenerateReqInput:
+class VertexGenerateReqInput(BaseReq):
     instances: List[dict]
     parameters: Optional[dict] = None


 @dataclass
-class RpcReqInput:
+class RpcReqInput(BaseReq):
     method: str
     parameters: Optional[Dict] = None


 @dataclass
-class RpcReqOutput:
+class RpcReqOutput(BaseReq):
     success: bool
     message: str


 @dataclass
-class LoadLoRAAdapterReqInput:
+class LoadLoRAAdapterReqInput(BaseReq):
     # The name of the lora module to newly loaded.
     lora_name: str
     # The path of loading.

@@ -1359,7 +1357,7 @@ class LoadLoRAAdapterReqInput:

 @dataclass
-class UnloadLoRAAdapterReqInput:
+class UnloadLoRAAdapterReqInput(BaseReq):
     # The name of lora module to unload.
     lora_name: str
     # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.

@@ -1373,23 +1371,23 @@ class UnloadLoRAAdapterReqInput:

 @dataclass
-class LoRAUpdateResult:
+class LoRAUpdateOutput(BaseReq):
     success: bool
     error_message: Optional[str] = None
     loaded_adapters: Optional[Dict[str, LoRARef]] = None


-LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput


 @dataclass
-class MultiTokenizerRegisterReq:
-    rids: Optional[Union[List[str], str]] = None
+class MultiTokenizerRegisterReq(BaseBatchReq):
     ipc_name: Optional[str] = None


 @dataclass
 class MultiTokenizerWrapper:
+    # FIXME(lsyin): remove this
     worker_id: int
     obj: Optional[Any] = None

@@ -1400,17 +1398,17 @@ class BlockReqType(Enum):

 @dataclass
-class BlockReqInput:
+class BlockReqInput(BaseReq):
     type: BlockReqType


 @dataclass
-class GetLoadReqInput:
+class GetLoadReqInput(BaseReq):
     pass


 @dataclass
-class GetLoadReqOutput:
+class GetLoadReqOutput(BaseReq):
     dp_rank: int
     num_reqs: int
     num_waiting_reqs: int

@@ -1418,5 +1416,31 @@ class GetLoadReqOutput:

 @dataclass
-class WatchLoadUpdateReq:
+class WatchLoadUpdateReq(BaseReq):
     loads: List[GetLoadReqOutput]
+
+
+def _check_all_req_types():
+    """A helper function to check all request types are defined in this file."""
+    import inspect
+    import sys
+
+    all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
+    for class_type in all_classes:
+        # check its name
+        name = class_type[0]
+        is_io_struct = (
+            name.endswith("Req")
+            or name.endswith("Input")
+            or name.endswith("Output")
+        )
+        is_base_req = issubclass(class_type[1], BaseReq) or issubclass(
+            class_type[1], BaseBatchReq
+        )
+        if is_io_struct and not is_base_req:
+            raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
+        if is_base_req and not is_io_struct:
+            raise ValueError(
+                f"{name} is a subclass of BaseReq but not follow the naming convention."
+            )
+
+
+_check_all_req_types()
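The new _check_all_req_types runs at import time and enforces the convention in both directions: anything named *Req/*Input/*Output must subclass BaseReq or BaseBatchReq, and any subclass must follow the naming. A condensed re-creation of the same logic, scoped to an explicit class list (assuming the BaseReq/BaseBatchReq sketches above):

    def check_naming_convention(classes):
        for name, cls in classes:
            is_io_struct = name.endswith(("Req", "Input", "Output"))
            is_base_req = issubclass(cls, (BaseReq, BaseBatchReq))
            if is_io_struct and not is_base_req:
                # e.g. a hypothetical `class MyReqInput:` with no base
                raise ValueError(f"{name} must subclass BaseReq or BaseBatchReq.")
            if is_base_req and not is_io_struct:
                raise ValueError(f"{name} must end with Req, Input, or Output.")

So a violation fails as soon as io_struct is imported, rather than at some later runtime dispatch.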
python/sglang/srt/managers/multi_tokenizer_mixin.py

@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
+"""Mixin class and utils for multi-http-worker mode"""
 import asyncio
 import logging
 import multiprocessing as multiprocessing

@@ -30,10 +30,10 @@ import zmq.asyncio
 from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
-    BatchEmbeddingOut,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     MultiTokenizerRegisterReq,
     MultiTokenizerWrapper,
 )

@@ -83,8 +83,8 @@ class SocketMapping:
 def _handle_output_by_index(output, i):
     """NOTE: A maintainable method is better here."""
-    if isinstance(output, BatchTokenIDOut):
-        new_output = BatchTokenIDOut(
+    if isinstance(output, BatchTokenIDOutput):
+        new_output = BatchTokenIDOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]

@@ -198,8 +198,8 @@ def _handle_output_by_index(output, i):
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
         )
-    elif isinstance(output, BatchEmbeddingOut):
-        new_output = BatchEmbeddingOut(
+    elif isinstance(output, BatchEmbeddingOutput):
+        new_output = BatchEmbeddingOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]

@@ -216,8 +216,8 @@ def _handle_output_by_index(output, i):
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
         )
-    elif isinstance(output, BatchStrOut):
-        new_output = BatchStrOut(
+    elif isinstance(output, BatchStrOutput):
+        new_output = BatchStrOutput(
            rids=[output.rids[i]],
            finished_reasons=(
                [output.finished_reasons[i]]

@@ -314,8 +314,8 @@ def _handle_output_by_index(output, i):
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
         )
-    elif isinstance(output, BatchMultimodalOut):
-        new_output = BatchMultimodalOut(
+    elif isinstance(output, BatchMultimodalOutput):
+        new_output = BatchMultimodalOutput(
             rids=[output.rids[i]],
             finished_reasons=(
                 [output.finished_reasons[i]]

@@ -343,7 +343,7 @@ def _handle_output_by_index(output, i):
 class MultiHttpWorkerDetokenizerMixin:
-    """Mixin class for MultiTokenizerManager and DetokenizerManager"""
+    """Mixin class for DetokenizerManager"""

     def get_worker_ids_from_req_rids(self, rids):
         if isinstance(rids, list):

@@ -386,7 +386,7 @@ class MultiHttpWorkerDetokenizerMixin:
 class MultiTokenizerRouter:
-    """A router to receive requests from MultiTokenizerManager"""
+    """A router to receive requests from TokenizerWorker"""

     def __init__(
         self,

@@ -454,8 +454,8 @@ class MultiTokenizerRouter:
         self.socket_mapping.send_output(worker_id, new_recv_obj)


-class MultiTokenizerManager(TokenizerManager):
-    """
-    Multi Process Tokenizer Manager that tokenizes the text.
-    """
+class TokenizerWorker(TokenizerManager):
+    """
+    Tokenizer Worker in multi-http-worker mode
+    """

     def __init__(
         self,
python/sglang/srt/managers/scheduler.py

@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
+    ExpertDistributionReqType,
     FlushCacheReqInput,
     FlushCacheReqOutput,
     FreezeGCReq,

@@ -1487,12 +1488,12 @@ class Scheduler(
             req.priority = -sys.maxsize - 1
         elif not self.enable_priority_scheduling and req.priority is not None:
             abort_req = AbortReq(
-                req.rid,
                 finished_reason={
                     "type": "abort",
                     "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                     "message": "Using priority is disabled for this server. Please send a new request without a priority.",
                 },
+                rid=req.rid,
             )
             self.send_to_tokenizer.send_pyobj(abort_req)

@@ -1528,12 +1529,12 @@ class Scheduler(
                 self.send_to_tokenizer.send_pyobj(
                     AbortReq(
-                        req_to_abort.rid,
                         finished_reason={
                             "type": "abort",
                             "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                             "message": message,
                         },
+                        rid=req_to_abort.rid,
                     )
                 )
                 return req_to_abort.rid == recv_req.rid

@@ -2005,7 +2006,7 @@ class Scheduler(
         self.new_token_ratio = new_token_ratio

         for req in reqs_to_abort:
             self.send_to_tokenizer.send_pyobj(
-                AbortReq(req.rid, abort_reason=req.to_abort_message)
+                AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
             )
         logger.info(

@@ -2575,7 +2576,7 @@ class Scheduler(
                 if self.enable_hicache_storage:
                     # to release prefetch events associated with the request
                     self.tree_cache.release_aborted_request(req.rid)
-                self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+                self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
                 # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
                 if self.disaggregation_mode == DisaggregationMode.DECODE:
                     self.tree_cache.cache_finished_req(req)

@@ -2687,11 +2688,12 @@ class Scheduler(
         return SlowDownReqOutput()

     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
-        if recv_req == ExpertDistributionReq.START_RECORD:
+        action = recv_req.action
+        if action == ExpertDistributionReqType.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
-        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+        elif action == ExpertDistributionReqType.STOP_RECORD:
             get_global_expert_distribution_recorder().stop_record()
-        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+        elif action == ExpertDistributionReqType.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")

@@ -2774,7 +2776,8 @@ class IdleSleeper:
 def is_health_check_generate_req(recv_req):
-    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+    rid = getattr(recv_req, "rid", None)
+    return rid is not None and rid.startswith("HEALTH_CHECK")


 def is_work_request(recv_req):
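The is_health_check_generate_req change at the end of this file is a direct consequence of BaseReq: rid now defaults to None rather than "", and getattr(recv_req, "rid", "") returns that None because the fallback only applies when the attribute is missing, so .startswith would raise. A small repro of the failure mode and the fix:

    class Recv:
        rid = None  # what a BaseReq-derived request now carries by default

    def old_helper(recv_req):
        return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")

    def new_helper(recv_req):
        rid = getattr(recv_req, "rid", None)
        return rid is not None and rid.startswith("HEALTH_CHECK")

    # old_helper(Recv()) raises AttributeError: 'NoneType' object has no
    # attribute 'startswith'; the new helper degrades to False instead.
    assert new_helper(Recv()) is False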
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -9,7 +9,11 @@ import torch
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
+)
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch

 if TYPE_CHECKING:

@@ -140,7 +144,7 @@ class SchedulerOutputProcessorMixin:
                         logger.error(
                             f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                         )
-                        self.abort_request(AbortReq(req.rid))
+                        self.abort_request(AbortReq(rid=req.rid))
                         req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished

@@ -292,7 +296,7 @@ class SchedulerOutputProcessorMixin:
                     logger.error(
                         f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                     )
-                    self.abort_request(AbortReq(req.rid))
+                    self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()

         self.set_next_batch_sampling_info_done(batch)

@@ -714,8 +718,7 @@ class SchedulerOutputProcessorMixin:
             return

         self.send_to_detokenizer.send_pyobj(
-            BatchTokenIDOut(
-                rids,
+            BatchTokenIDOutput(
                 finished_reasons,
                 decoded_texts,
                 decode_ids_list,

@@ -741,6 +744,7 @@ class SchedulerOutputProcessorMixin:
                 output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx,
                 output_hidden_states,
+                rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
             )

@@ -761,12 +765,12 @@ class SchedulerOutputProcessorMixin:
             prompt_tokens.append(len(req.origin_input_ids))
             cached_tokens.append(req.cached_tokens)
         self.send_to_detokenizer.send_pyobj(
-            BatchEmbeddingOut(
-                rids,
+            BatchEmbeddingOutput(
                 finished_reasons,
                 embeddings,
                 prompt_tokens,
                 cached_tokens,
+                rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
             )
python/sglang/srt/managers/tokenizer_communicator_mixin.py

@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqOutput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
+    ExpertDistributionReqType,
     FlushCacheReqInput,
     FlushCacheReqOutput,
     GetInternalStateReq,

@@ -44,7 +45,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqOutput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
+    LoRAUpdateOutput,
     MultiTokenizerWrapper,
     OpenSessionReqInput,
     ProfileReq,

@@ -276,7 +277,7 @@ class TokenizerCommunicatorMixin:
                 self.expert_distribution_communicator.handle_recv,
             ),
             (
-                LoRAUpdateResult,
+                LoRAUpdateOutput,
                 self.update_lora_adapter_communicator.handle_recv,
             ),
             (

@@ -335,15 +336,18 @@ class TokenizerCommunicatorMixin:
     async def start_expert_distribution_record(self: TokenizerManager):
         self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
+        req = ExpertDistributionReq(action=ExpertDistributionReqType.START_RECORD)
+        await self.expert_distribution_communicator(req)

     async def stop_expert_distribution_record(self: TokenizerManager):
         self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
+        req = ExpertDistributionReq(action=ExpertDistributionReqType.STOP_RECORD)
+        await self.expert_distribution_communicator(req)

     async def dump_expert_distribution_record(self: TokenizerManager):
         self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
+        req = ExpertDistributionReq(action=ExpertDistributionReqType.DUMP_RECORD)
+        await self.expert_distribution_communicator(req)

     async def init_weights_update_group(
         self: TokenizerManager,
python/sglang/srt/managers/tokenizer_manager.py
View file @
3c699772
...
@@ -48,18 +48,17 @@ from sglang.srt.hf_transformers_utils import (
...
@@ -48,18 +48,17 @@ from sglang.srt.hf_transformers_utils import (
get_tokenizer
,
get_tokenizer
,
get_tokenizer_from_processor
,
get_tokenizer_from_processor
,
)
)
from
sglang.srt.lora.lora_registry
import
LoRARef
,
LoRARegistry
from
sglang.srt.lora.lora_registry
import
LoRARegistry
from
sglang.srt.managers.async_dynamic_batch_tokenizer
import
AsyncDynamicbatchTokenizer
from
sglang.srt.managers.async_dynamic_batch_tokenizer
import
AsyncDynamicbatchTokenizer
from
sglang.srt.managers.disagg_service
import
start_disagg_service
from
sglang.srt.managers.disagg_service
import
start_disagg_service
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
AbortReq
,
BatchEmbeddingOut
,
BatchEmbeddingOut
put
,
BatchMultimodalOut
,
BatchMultimodalOut
put
,
BatchStrOut
,
BatchStrOut
put
,
BatchTokenIDOut
,
BatchTokenIDOut
put
,
BatchTokenizedEmbeddingReqInput
,
BatchTokenizedEmbeddingReqInput
,
BatchTokenizedGenerateReqInput
,
BatchTokenizedGenerateReqInput
,
CloseSessionReqInput
,
ConfigureLoggingReq
,
ConfigureLoggingReq
,
EmbeddingReqInput
,
EmbeddingReqInput
,
FreezeGCReq
,
FreezeGCReq
,
...
@@ -67,7 +66,6 @@ from sglang.srt.managers.io_struct import (
...
@@ -67,7 +66,6 @@ from sglang.srt.managers.io_struct import (
GetLoadReqInput
,
GetLoadReqInput
,
HealthCheckOutput
,
HealthCheckOutput
,
MultiTokenizerWrapper
,
MultiTokenizerWrapper
,
OpenSessionReqInput
,
OpenSessionReqOutput
,
OpenSessionReqOutput
,
SessionParams
,
SessionParams
,
TokenizedEmbeddingReqInput
,
TokenizedEmbeddingReqInput
,
...
@@ -341,10 +339,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
...
@@ -341,10 +339,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
[
[
(
(
(
(
BatchStrOut
,
BatchStrOut
put
,
BatchEmbeddingOut
,
BatchEmbeddingOut
put
,
BatchTokenIDOut
,
BatchTokenIDOut
put
,
BatchMultimodalOut
,
BatchMultimodalOut
put
,
),
),
self
.
_handle_batch_output
,
self
.
_handle_batch_output
,
),
),
...
@@ -716,7 +714,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             )
             tokenized_obj = TokenizedGenerateReqInput(
-                obj.rid,
                 input_text,
                 input_ids,
                 mm_inputs,
...
@@ -726,6 +723,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 obj.top_logprobs_num,
                 obj.token_ids_logprob,
                 obj.stream,
+                rid=obj.rid,
                 bootstrap_host=obj.bootstrap_host,
                 bootstrap_port=obj.bootstrap_port,
                 bootstrap_room=obj.bootstrap_room,
...
@@ -740,12 +738,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                obj.rid,
                 input_text,
                 input_ids,
                 mm_inputs,
                 token_type_ids,
                 sampling_params,
+                rid=obj.rid,
                 priority=obj.priority,
             )
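In both constructors, `rid` moves from the first positional argument to a trailing keyword argument. That call shape is exactly what you get if `rid` now lives on a shared base io class as a keyword-only dataclass field; the sketch below shows the idea (Python 3.10+ for `kw_only`), but the base-class definition itself is an assumption, not the commit's actual code:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class BaseReq:
    # Assumed: the request id sits on a common base as a keyword-only
    # field, so subclasses keep their positional argument order.
    rid: Optional[str] = field(default=None, kw_only=True)


@dataclass
class TokenizedReqSketch(BaseReq):
    # Hypothetical payload fields standing in for the real ones.
    input_text: str
    input_ids: List[int]


# Matches the new call shape above: positional payload, keyword rid.
req = TokenizedReqSketch("hello", [1, 2, 3], rid="req-0")
print(req.rid)  # req-0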
...
@@ -1038,7 +1036,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
-        req = AbortReq(rid, abort_all)
+        req = AbortReq(rid=rid, abort_all=abort_all)
         self.send_to_scheduler.send_pyobj(req)
         if self.enable_metrics:
             # TODO: also use custom_labels from the request
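`AbortReq(rid, abort_all)` becomes `AbortReq(rid=rid, abort_all=abort_all)`. Keyword construction decouples the call site from field order, which matters once shared fields such as `rid` migrate onto a base class. A sketch with assumed field order and defaults:

from dataclasses import dataclass


@dataclass
class AbortReq:
    # Field names come from the call site above; order and defaults are assumptions.
    rid: str = ""
    abort_all: bool = False


# Stays correct even if the fields above are later reordered or moved
# into a base class; positional AbortReq("req-42", True) would not.
req = AbortReq(rid="req-42", abort_all=False)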
...
@@ -1303,7 +1301,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
     def _handle_batch_output(
         self,
         recv_obj: Union[
-            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
+            BatchStrOutput,
+            BatchEmbeddingOutput,
+            BatchMultimodalOutput,
+            BatchTokenIDOutput,
         ],
     ):
         for i, rid in enumerate(recv_obj.rids):
...
@@ -1337,7 +1338,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     i,
                 )
-            if not isinstance(recv_obj, BatchEmbeddingOut):
+            if not isinstance(recv_obj, BatchEmbeddingOutput):
                 meta_info.update(
                     {
                         "completion_tokens": recv_obj.completion_tokens[i],
...
@@ -1348,7 +1349,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if getattr(recv_obj, "output_hidden_states", None):
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
-            if isinstance(recv_obj, BatchStrOut):
+            if isinstance(recv_obj, BatchStrOutput):
                 state.text += recv_obj.output_strs[i]
                 if state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
...
@@ -1363,7 +1364,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchTokenIDOut):
+            elif isinstance(recv_obj, BatchTokenIDOutput):
                 if self.server_args.stream_output and state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
                     output_token_ids = state.output_ids[state.last_output_offset:]
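For context on the unchanged streaming lines in this hunk: the manager accumulates all received token ids in `state.output_ids` and, per chunk, emits only the slice past `state.last_output_offset`. The same incremental-slice technique in isolation (the offset update is assumed to happen after each emission):

output_ids: list = []
last_output_offset = 0


def on_chunk(new_ids: list) -> list:
    # Append the newly decoded ids, then return only the unseen suffix.
    global last_output_offset
    output_ids.extend(new_ids)
    chunk = output_ids[last_output_offset:]
    last_output_offset = len(output_ids)
    return chunk


print(on_chunk([1, 2]))  # [1, 2]
print(on_chunk([3]))     # [3]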
...
@@ -1376,10 +1377,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchMultimodalOut):
+            elif isinstance(recv_obj, BatchMultimodalOutput):
                 raise NotImplementedError("BatchMultimodalOut not implemented")
             else:
-                assert isinstance(recv_obj, BatchEmbeddingOut)
+                assert isinstance(recv_obj, BatchEmbeddingOutput)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
...
@@ -1418,7 +1419,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
-        recv_obj: BatchStrOut,
+        recv_obj: BatchStrOutput,
         recv_obj_index: int,
     ):
         if recv_obj.input_token_logprobs_val is None:
...
@@ -1536,7 +1537,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 ret.append(None)
         return ret

-    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
         completion_tokens = (
             recv_obj.completion_tokens[i]
             if getattr(recv_obj, "completion_tokens", None)
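`collect_metrics` reads `completion_tokens` through `getattr(..., None)`, so the same handler tolerates output classes that lack the field or carry `None` (embedding outputs, presumably). The guard in isolation:

class WithTokens:
    completion_tokens = [7]


class WithoutTokens:
    pass


def tokens_at(recv_obj, i: int):
    # A missing attribute (or None) falls back safely instead of raising.
    vals = getattr(recv_obj, "completion_tokens", None)
    return vals[i] if vals else None


print(tokens_at(WithTokens(), 0))     # 7
print(tokens_at(WithoutTokens(), 0))  # None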
...
@@ -1632,7 +1633,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         asyncio.create_task(asyncio.to_thread(background_task))

-    def _handle_abort_req(self, recv_obj):
+    def _handle_abort_req(self, recv_obj: AbortReq):
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
...