sglang / Commits / e719bb0e

Commit e719bb0e (unverified)
Authored Sep 07, 2025 by Liangsheng Yin; committed by GitHub, Sep 07, 2025
Parent: 06724683

[1/2] Refactor multi-tokenizer manager (#10074)

Showing 6 changed files with 424 additions and 488 deletions (+424, -488).
Changed files:

  python/sglang/srt/entrypoints/engine.py               +24   -12
  python/sglang/srt/entrypoints/http_server.py          +26   -37
  python/sglang/srt/managers/detokenizer_manager.py      +4    -4
  python/sglang/srt/managers/disagg_service.py          +46    -0  (new file)
  python/sglang/srt/managers/multi_tokenizer_mixin.py  +318  -394
  python/sglang/srt/managers/tokenizer_manager.py        +6   -41
python/sglang/srt/entrypoints/engine.py

@@ -704,6 +704,24 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)
 
 
+def _init_tokenizer_manager(
+    server_args: ServerArgs, port_args: PortArgs
+) -> Tuple[TokenizerManager, TemplateManager]:
+    # Launch tokenizer process
+    tokenizer_manager = TokenizerManager(server_args, port_args)
+
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )
+
+    return tokenizer_manager, template_manager
+
+
 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
 ) -> Tuple[TokenizerManager, TemplateManager, Dict]:

@@ -816,23 +834,15 @@ def _launch_subprocesses(
         ),
     )
     detoken_proc.start()
 
+    # Init tokenizer manager first, as the bootstrap server is initialized here
     if server_args.tokenizer_worker_num > 1:
         # Launch multi-tokenizer router
         tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
         template_manager = None
     else:
-        # Launch tokenizer process
-        tokenizer_manager = TokenizerManager(server_args, port_args)
-
-        # Initialize templates
-        template_manager = TemplateManager()
-        template_manager.initialize_templates(
-            tokenizer_manager=tokenizer_manager,
-            model_path=server_args.model_path,
-            chat_template=server_args.chat_template,
-            completion_template=server_args.completion_template,
-        )
+        tokenizer_manager, template_manager = _init_tokenizer_manager(
+            server_args, port_args
+        )
 
     # Wait for the model to finish loading

@@ -856,5 +866,7 @@ def _launch_subprocesses(
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
+
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
+
     return tokenizer_manager, template_manager, scheduler_info
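The net effect on the engine side: single-worker startup goes through the new _init_tokenizer_manager helper, while multi-worker startup builds a MultiTokenizerRouter and defers template setup to the HTTP workers. A minimal, self-contained sketch of that branch (stub classes stand in for the real sglang types; this is an illustration, not the actual implementation):

class TokenizerManager: ...
class MultiTokenizerRouter: ...
class TemplateManager: ...

def pick_tokenizer_frontend(tokenizer_worker_num: int):
    # Mirrors the branch in _launch_subprocesses after this commit.
    if tokenizer_worker_num > 1:
        # Router mode: templates are initialized later, per HTTP worker.
        return MultiTokenizerRouter(), None
    # Single-worker mode: manager and templates are created together.
    return TokenizerManager(), TemplateManager()

assert pick_tokenizer_frontend(4)[1] is None
assert pick_tokenizer_frontend(1)[1] is not None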
python/sglang/srt/entrypoints/http_server.py

@@ -92,7 +92,6 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
     MultiTokenizerManager,
-    deserialize_data,
     get_main_process_id,
     read_from_shared_memory,
     write_data_for_multi_tokenizer,

@@ -136,21 +135,6 @@ def set_global_state(global_state: _GlobalState):
     _global_state = global_state
 
 
-# Function to set up all middlewares for multi-tokenizer compatibility
-def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
-    """Setup all middlewares for both single and multi-process modes"""
-    worker_pid = os.getpid()
-
-    if api_key:
-        add_api_key_middleware(app, api_key)
-        logger.info(f"Worker {worker_pid} added API key middleware")
-
-    if enable_metrics:
-        add_prometheus_middleware(app)
-        enable_func_timer()
-        logger.info(f"Worker {worker_pid} added prometheus middleware")
-
-
 async def init_multi_tokenizer() -> ServerArgs:
     """Read args information from shm and init tokenizer manager for current process"""
     pid = os.getpid()

@@ -158,11 +142,15 @@ async def init_multi_tokenizer() -> ServerArgs:
     logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
 
     # Read configuration from shared memory
-    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
-    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
-    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
-    port_args, server_args = deserialize_data(port_args_data, server_args_data)
-    scheduler_info = scheduler_info_data
+    port_args, server_args, scheduler_info = read_from_shared_memory(
+        f"multi_tokenizer_args_{main_pid}"
+    )
+    server_args: ServerArgs
+
+    # API key authentication is not supported in multi-tokenizer mode
+    assert (
+        server_args.api_key is None
+    ), "API key is not supported in multi-tokenizer mode"
 
     port_args.tokenizer_ipc_name = (
         f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"

@@ -193,13 +181,17 @@ async def init_multi_tokenizer() -> ServerArgs:
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    server_args = getattr(fast_api_app, "server_args", None)
-    if server_args is None:
+    if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
         # Initialize multi-tokenizer support for worker processes
-        fast_api_app.server_args = await init_multi_tokenizer()
-        setup_middlewares(
-            fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
-        )
+        fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
+
+        # only metrics middleware is supported in multi-tokenizer mode
+        worker_pid = os.getpid()
+        if fast_api_app.server_args.enable_metrics:
+            add_prometheus_middleware(app)
+            enable_func_timer()
+            logger.info(f"Worker {worker_pid} added prometheus middleware")
 
     fast_api_app.warmup_thread = threading.Thread(
         target=_wait_and_warmup,
         args=(

@@ -1187,13 +1179,11 @@ def launch_server(
     )
 
     if server_args.tokenizer_worker_num > 1:
-        port_args_shm, server_args_shm, scheduler_info_shm = (
-            write_data_for_multi_tokenizer(
-                port_args,
-                server_args,
-                scheduler_info,
-            )
-        )
+        multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
+            port_args,
+            server_args,
+            scheduler_info,
+        )
     else:
         # Add api key authorization
         if server_args.api_key:

@@ -1239,6 +1229,7 @@ def launch_server(
             workers=server_args.tokenizer_worker_num,
         )
     else:
+        app.is_single_tokenizer_mode = True
        uvicorn.run(
            app,
            host=server_args.host,

@@ -1249,10 +1240,8 @@ def launch_server(
         )
     finally:
         if server_args.tokenizer_worker_num > 1:
-            port_args_shm.unlink()
-            server_args_shm.unlink()
-            scheduler_info_shm.unlink()
-            _global_state.tokenizer_manager.clear_tokenizer_mapping()
+            multi_tokenizer_args_shm.unlink()
+            _global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
        else:
            warmup_thread.join()
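Why this works at runtime: in multi-worker mode uvicorn forks tokenizer_worker_num workers from the main process, each worker's lifespan sees is_single_tokenizer_mode unset, and it rebuilds its configuration from the shared-memory blob named after the parent pid. A hedged sketch of the naming contract (the None fallback for the main process is an assumption added here; get_main_process_id in the diff relies on the private _parent_pid attribute):

import multiprocessing
import os

def shm_name_for_this_process() -> str:
    # The main process writes f"multi_tokenizer_args_{os.getpid()}" before
    # uvicorn.run(); a forked worker derives the same name from its parent pid.
    main_pid = multiprocessing.current_process()._parent_pid
    if main_pid is None:  # assumption: we are the main process itself
        main_pid = os.getpid()
    return f"multi_tokenizer_args_{main_pid}"

print(shm_name_for_this_process())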
python/sglang/srt/managers/detokenizer_manager.py

@@ -34,7 +34,7 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     MultiTokenizerRegisterReq,
 )
-from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
+from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,

@@ -69,7 +69,7 @@ class DecodeStatus:
     sent_offset: int = 0
 
 
-class DetokenizerManager(MultiTokenizerMixin):
+class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
     """DetokenizerManager is a process that detokenizes the token ids."""
 
     def __init__(

@@ -289,11 +289,11 @@ def run_detokenizer_process(
     try:
         manager = DetokenizerManager(server_args, port_args)
         if server_args.tokenizer_worker_num > 1:
-            manager.multi_tokenizer_manager_event_loop()
+            manager.multi_http_worker_event_loop()
         else:
             manager.event_loop()
     except Exception:
-        manager.clear_tokenizer_mapping()
+        manager.socket_mapping.clear_all_sockets()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
python/sglang/srt/managers/disagg_service.py  (new file, mode 100644)

"""Start bootstrap/kv-store-related server"""

import os
from typing import Type

from sglang.srt.disaggregation.base import BaseKVBootstrapServer
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    KVClassType,
    TransferBackend,
    get_kv_class,
)
from sglang.srt.server_args import ServerArgs


def start_disagg_service(
    server_args: ServerArgs,
):
    # Start kv bootstrap server on prefill
    disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
    transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)

    if disagg_mode == DisaggregationMode.PREFILL:
        # only start bootstrap server on prefill tm
        kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
            transfer_backend, KVClassType.BOOTSTRAP_SERVER
        )
        bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
            host=server_args.host,
            port=server_args.disaggregation_bootstrap_port,
        )
        is_create_store = (
            server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
        )
        if is_create_store:
            try:
                from mf_adapter import create_config_store

                ascend_url = os.getenv("ASCEND_MF_STORE_URL")
                create_config_store(ascend_url)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to create mf store, invalid ascend_url. With exception {e}"
                ) from e

        return bootstrap_server
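Note the enum round trip this new module introduces: server_args carries plain strings, and the service coerces them with DisaggregationMode(...) and TransferBackend(...), so the Ascend check compares against TransferBackend.ASCEND instead of the raw string "ascend" used by the old TokenizerManager.init_disaggregation. A small illustration of the pattern (a generic enum with illustrative members, not the sglang class):

from enum import Enum

class TransferBackend(Enum):
    MOONCAKE = "mooncake"  # member names illustrative
    ASCEND = "ascend"

backend = TransferBackend("ascend")  # CLI string -> enum member
assert backend == TransferBackend.ASCEND
assert backend != "ascend"           # raw-string comparison would silently fail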
python/sglang/srt/managers/multi_tokenizer_mixin.py

@@ -13,21 +13,21 @@
 # ==============================================================================
 """MultiTokenizerMixin is a class that provides necessary methods for MultiTokenizerManager and DetokenizerManager."""
 import asyncio
-import dataclasses
-import json
 import logging
 import multiprocessing as multiprocessing
 import os
+import pickle
 import sys
 import threading
 from multiprocessing import shared_memory
-from typing import Dict
+from typing import Any, Dict
 
 import setproctitle
 import zmq
 import zmq.asyncio
 
 from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchMultimodalOut,
@@ -44,44 +44,42 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
 
-class MultiTokenizerMixin:
-    """Mixin class for MultiTokenizerManager and DetokenizerManager"""
-
-    def create_sockets_mapping(self):
-        if not hasattr(self, "tokenizer_mapping"):
-            self.tokenizer_mapping = {}
-        # Create ZMQ context if needed
-        if not hasattr(self, "_zmq_context"):
-            self._zmq_context = zmq.Context()
-
-    def init_tokenizer_mapping(
-        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
-    ):
-        """init tokenizer mapping from register request"""
-        ipc_name = recv_obj.ipc_name
-        worker_id_int = int(worker_id)
-
-        if worker_id_int not in self.tokenizer_mapping:
-            socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
-            self.tokenizer_mapping[worker_id_int] = socket
-            self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj)
-            return True
-        else:
-            return False
-
-    def register_tokenizer_ipc(self, recv_obj, worker_id):
-        if worker_id not in self.tokenizer_mapping:
-            # register the worker if not already done
-            if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                return self.init_tokenizer_mapping(recv_obj, worker_id)
-            else:
-                logger.error(
-                    f"Worker {worker_id} not registered and not found in tokenizer mapping. "
-                    "Please ensure the worker is registered correctly."
-                )
-        return False
-
-    def _handle_output_by_index(self, output, i):
+class SocketMapping:
+    def __init__(self):
+        self._zmq_context = zmq.Context()
+        self._mapping: Dict[str, zmq.Socket] = {}
+
+    def clear_all_sockets(self):
+        for socket in self._mapping.values():
+            socket.close()
+        self._mapping.clear()
+
+    def register_ipc_mapping(
+        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
+    ):
+        type_str = "tokenizer" if is_tokenizer else "detokenizer"
+        if worker_id in self._mapping:
+            logger.warning(
+                f"{type_str} already registered with worker {worker_id}, skipping..."
+            )
+            return
+        logger.info(
+            f"{type_str} not registered with worker {worker_id}, registering..."
+        )
+        socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
+        self._mapping[worker_id] = socket
+        self._mapping[worker_id].send_pyobj(recv_obj)
+
+    def send_output(self, worker_id: str, output: Any):
+        if worker_id not in self._mapping:
+            logger.error(
+                f"worker ID {worker_id} not registered. Check if the server Process is alive"
+            )
+            return
+        self._mapping[worker_id].send_pyobj(output)
+
+
+def _handle_output_by_index(output, i):
     """NOTE: A maintainable method is better here."""
     if isinstance(output, BatchTokenIDOut):
         new_output = BatchTokenIDOut(
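SocketMapping encapsulates what the old mixin did with ad-hoc hasattr checks: one ZMQ PUSH socket per worker id, created on first registration and reused afterwards, with the registration request echoed back as the first message. A self-contained illustration of the underlying pattern (plain pyzmq; an inproc endpoint stands in for the ipc name carried by recv_obj):

import zmq

ctx = zmq.Context()
endpoint = "inproc://worker-42"      # stands in for recv_obj.ipc_name
worker_side = ctx.socket(zmq.PULL)   # the HTTP worker's receiving end
worker_side.bind(endpoint)

mapping: dict = {}

def register(worker_id, ipc_name, register_req):
    if worker_id in mapping:
        return                       # already registered: skip
    sock = ctx.socket(zmq.PUSH)
    sock.connect(ipc_name)
    mapping[worker_id] = sock
    sock.send_pyobj(register_req)    # ack registration, as the class above does

register("42", endpoint, {"type": "register"})
assert worker_side.recv_pyobj() == {"type": "register"}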
@@ -94,9 +92,7 @@ class MultiTokenizerMixin:
             decoded_texts=(
                 [output.decoded_texts[i]] if len(output.decoded_texts) > i else None
             ),
-            decode_ids=(
-                [output.decode_ids[i]] if len(output.decode_ids) > i else None
-            ),
+            decode_ids=([output.decode_ids[i]] if len(output.decode_ids) > i else None),
             read_offsets=(
                 [output.read_offsets[i]] if len(output.read_offsets) > i else None
             ),

@@ -130,9 +126,7 @@ class MultiTokenizerMixin:
                 [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
             ),
             spec_verify_ct=(
-                [output.spec_verify_ct[i]]
-                if len(output.spec_verify_ct) > i
-                else None
+                [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
             ),
             input_token_logprobs_val=(
                 [output.input_token_logprobs_val[i]]

@@ -208,9 +202,7 @@ class MultiTokenizerMixin:
                 if len(output.finished_reasons) > i
                 else None
             ),
-            embeddings=(
-                [output.embeddings[i]] if len(output.embeddings) > i else None
-            ),
+            embeddings=([output.embeddings[i]] if len(output.embeddings) > i else None),
             prompt_tokens=(
                 [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
             ),

@@ -246,9 +238,7 @@ class MultiTokenizerMixin:
                 [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
             ),
             spec_verify_ct=(
-                [output.spec_verify_ct[i]]
-                if len(output.spec_verify_ct) > i
-                else None
+                [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
             ),
             input_token_logprobs_val=(
                 [output.input_token_logprobs_val[i]]
@@ -341,6 +331,10 @@ class MultiTokenizerMixin:
             new_output = output
     return new_output
 
 
+class MultiHttpWorkerDetokenizerMixin:
+    """Mixin class for MultiTokenizerManager and DetokenizerManager"""
+
     def get_worker_ids_from_req_rids(self, rids):
         if isinstance(rids, list):
             worker_ids = [int(rid.split("_")[0]) for rid in rids]

@@ -350,9 +344,9 @@ class MultiTokenizerMixin:
             worker_ids = []
         return worker_ids
 
-    def multi_tokenizer_manager_event_loop(self):
-        """The event loop that handles requests, for multi tokenizer manager mode only"""
-        self.create_sockets_mapping()
+    def multi_http_worker_event_loop(self):
+        """The event loop that handles requests, for multi-http-worker mode"""
+        self.socket_mapping = SocketMapping()
         while True:
             recv_obj = self.recv_from_scheduler.recv_pyobj()
             output = self._request_dispatcher(recv_obj)
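Routing hinges on the rid convention: each request id carries the issuing worker's id before the first underscore, so int(rid.split("_")[0]) recovers the destination. A small example of the parsing (the "{worker_id}_{suffix}" layout is inferred from this code, not stated elsewhere in the diff):

def get_worker_ids(rids):
    # Mirrors get_worker_ids_from_req_rids for the list case.
    if isinstance(rids, list):
        return [int(rid.split("_")[0]) for rid in rids]
    return []

assert get_worker_ids(["12345_req-a", "12345_req-b", "67890_req-a"]) == [
    12345,
    12345,
    67890,
]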
@@ -369,31 +363,15 @@ class MultiTokenizerMixin:
         # Send data using the corresponding socket
         for i, worker_id in enumerate(worker_ids):
             if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                if self.register_tokenizer_ipc(recv_obj, worker_id):
-                    logger.info(
-                        f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
-                    )
-                continue
+                self.socket_mapping.register_ipc_mapping(
+                    recv_obj, worker_id, is_tokenizer=False
+                )
             else:
-                if worker_id not in self.tokenizer_mapping:
-                    logger.error(
-                        f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
-                    )
-                    continue
-                new_output = self._handle_output_by_index(output, i)
-                self.tokenizer_mapping[worker_id].send_pyobj(new_output)
-
-    def clear_tokenizer_mapping(self):
-        if hasattr(self, "tokenizer_mapping"):
-            for socket in self.tokenizer_mapping.values():
-                try:
-                    socket.close()
-                except Exception as e:
-                    logger.warning(f"Failed to close socket: {e}")
-            self.tokenizer_mapping.clear()
+                new_output = _handle_output_by_index(output, i)
+                self.socket_mapping.send_output(worker_id, new_output)
 
 
-class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
+class MultiTokenizerRouter:
     """A router to receive requests from MultiTokenizerManager"""
 
     def __init__(
@@ -422,7 +400,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
         self._handle_task = asyncio.run_coroutine_threadsafe(
             print_exception_wrapper(self.handle_loop), self._loop
         )
-        self.init_disaggregation()
+        self.disaggregation_bootstrap_server = start_disagg_service(self.server_args)
 
     def _run_loop(self):
         self._loop.run_forever()

@@ -434,7 +412,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
     async def handle_loop(self):
         # special reqs will recv from scheduler, need to route to right worker
-        self.create_sockets_mapping()
+        self.socket_mapping = SocketMapping()
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             await self._distribute_result_to_workers(recv_obj)

@@ -454,22 +432,15 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
         # Distribute result to each worker
         for i, worker_id in enumerate(worker_ids):
             if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                if self.register_tokenizer_ipc(recv_obj, worker_id):
-                    logger.info(
-                        f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
-                    )
-                continue
+                self.socket_mapping.register_ipc_mapping(
+                    recv_obj, worker_id, is_tokenizer=True
+                )
             else:
-                if worker_id not in self.tokenizer_mapping:
-                    logger.error(
-                        f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
-                    )
-                    continue
-                new_recv_obj = self._handle_output_by_index(recv_obj, i)
-                self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
+                new_recv_obj = _handle_output_by_index(recv_obj, i)
+                self.socket_mapping.send_output(worker_id, new_recv_obj)
 
 
-class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
+class MultiTokenizerManager(TokenizerManager):
     """Multi Process Tokenizer Manager that tokenizes the text."""
 
     def __init__(
@@ -535,42 +506,14 @@ async def print_exception_wrapper(func):
         sys.exit(1)
 
 
-def serialize_port_args(port_args: PortArgs) -> dict:
-    """Serialize PortArgs into a shareable dictionary"""
-    return {
-        "tokenizer_ipc_name": port_args.tokenizer_ipc_name,
-        "scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
-        "detokenizer_ipc_name": port_args.detokenizer_ipc_name,
-        "nccl_port": port_args.nccl_port,
-        "rpc_ipc_name": port_args.rpc_ipc_name,
-        "metrics_ipc_name": port_args.metrics_ipc_name,
-        "tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
-    }
-
-
-def deserialize_data(port_args: dict, server_args: dict):
-    """Deserialize data from shared dictionaries"""
-    return PortArgs(**port_args), ServerArgs(**server_args)
-
-
-def serialize_server_args(server_args: ServerArgs) -> dict:
-    """Serialize ServerArgs into a shareable dictionary"""
-    return dataclasses.asdict(server_args)
-
-
-def serialize_scheduler_info(scheduler_info: Dict) -> dict:
-    """Serialize scheduler_info into a shareable dictionary"""
-    return scheduler_info
-
-
-def deserialize_scheduler_info(data: dict) -> Dict:
-    """Deserialize scheduler_info from a shared dictionary"""
-    return data
+def get_main_process_id() -> int:
+    """Get the main process ID"""
+    return multiprocessing.current_process()._parent_pid
 
 
-def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
+def write_to_shared_memory(obj, name: str) -> shared_memory.SharedMemory:
     """Write data to shared memory"""
-    serialized = json.dumps(data).encode("utf-8")
+    serialized = pickle.dumps(obj)
     size = len(serialized)
     try:
         # Try to open existing shared memory

@@ -588,22 +531,17 @@ def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
     return shm
 
 
-def read_from_shared_memory(name: str) -> dict:
+def read_from_shared_memory(name: str) -> Any:
     """Read data from shared memory"""
     try:
         shm = shared_memory.SharedMemory(name=name)
-        data = json.loads(bytes(shm.buf).decode("utf-8"))
+        data = pickle.loads(bytes(shm.buf))
         shm.close()
         return data
     except FileNotFoundError:
         raise FileNotFoundError(f"Shared memory {name} not found")
 
 
-def get_main_process_id() -> int:
-    """Get the main process ID"""
-    return multiprocessing.current_process()._parent_pid
-
-
 def write_data_for_multi_tokenizer(
     port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
 ):

@@ -612,22 +550,8 @@ def write_data_for_multi_tokenizer(
     main_pid = get_main_process_id()
     current_pid = os.getpid()
     logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
 
-    # Write port_args to shared memory
-    port_args_shm = write_to_shared_memory(
-        serialize_port_args(port_args), f"port_args_{current_pid}"
-    )
-    # Write server_args to shared memory
-    server_args_shm = write_to_shared_memory(
-        serialize_server_args(server_args), f"server_args_{current_pid}"
-    )
-    # Write scheduler_info to shared memory
-    scheduler_info_shm = write_to_shared_memory(
-        serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
-    )
-
-    port_args_shm.close()
-    server_args_shm.close()
-    scheduler_info_shm.close()
-
-    return port_args_shm, server_args_shm, scheduler_info_shm
+    args = (port_args, server_args, scheduler_info)
+    args_shm = write_to_shared_memory(args, f"multi_tokenizer_args_{current_pid}")
+    args_shm.close()
+
+    return args_shm
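Replacing the three JSON segments with one pickled tuple removes the per-type serializer functions entirely: pickle handles the PortArgs and ServerArgs dataclasses directly, and pickle.loads stops at the pickle STOP opcode, so any zero padding SharedMemory leaves past the payload is harmless. A standalone round trip of the same scheme (helper names are illustrative, not the sglang functions):

import pickle
from multiprocessing import shared_memory

def write_shm(obj, name):
    payload = pickle.dumps(obj)
    shm = shared_memory.SharedMemory(create=True, size=len(payload), name=name)
    shm.buf[: len(payload)] = payload
    return shm

def read_shm(name):
    shm = shared_memory.SharedMemory(name=name)
    obj = pickle.loads(bytes(shm.buf))  # trailing padding bytes are ignored
    shm.close()
    return obj

shm = write_shm(("port_args", {"max_req_input_len": 4096}), "demo_args")
try:
    assert read_shm("demo_args") == ("port_args", {"max_req_input_len": 4096})
finally:
    shm.close()
    shm.unlink()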
python/sglang/srt/managers/tokenizer_manager.py

@@ -54,19 +54,14 @@ from fastapi import BackgroundTasks
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.base import BaseKVBootstrapServer
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    KVClassType,
-    TransferBackend,
-    get_kv_class,
-)
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
     get_tokenizer_from_processor,
 )
 from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,

@@ -321,8 +316,10 @@ class TokenizerManager:
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()
 
-        # For PD disaggregtion
-        self.init_disaggregation()
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.bootstrap_server = start_disagg_service(self.server_args)
 
         # For load balancing
         self.current_load = 0

@@ -471,38 +468,6 @@ class TokenizerManager:
             ]
         )
 
-    def init_disaggregation(self):
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
-                host=self.server_args.host,
-                port=self.server_args.disaggregation_bootstrap_port,
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
-
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],