Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
81d27c8e
"vscode:/vscode.git/clone" did not exist on "6fc9d2270786a76952a5d55ec75a9f9dcde73e26"
Unverified
Commit
81d27c8e
authored
Jan 19, 2025
by
fzyzcjy
Committed by
GitHub
Jan 18, 2025
Browse files
Refactor to add TypeBasedDispatcher to simplify dispatching (#2958)
parent
4d4cdb3f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
171 additions
and
164 deletions
+171
-164
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+52
-61
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+107
-102
python/sglang/utils.py
python/sglang/utils.py
+12
-1
No files found.
python/sglang/srt/managers/scheduler.py
View file @
81d27c8e
...
...
@@ -97,7 +97,7 @@ from sglang.srt.utils import (
set_random_seed
,
suppress_other_loggers
,
)
from
sglang.utils
import
get_exception_traceback
from
sglang.utils
import
TypeBasedDispatcher
,
get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -422,6 +422,34 @@ class Scheduler:
},
)
self
.
_dispatcher
=
TypeBasedDispatcher
(
[
(
TokenizedGenerateReqInput
,
self
.
handle_generate_request
),
(
TokenizedEmbeddingReqInput
,
self
.
handle_embedding_request
),
(
FlushCacheReq
,
self
.
flush_cache_wrapped
),
(
AbortReq
,
self
.
abort_request
),
(
UpdateWeightFromDiskReqInput
,
self
.
update_weights_from_disk
),
(
InitWeightsUpdateGroupReqInput
,
self
.
init_weights_update_group
),
(
UpdateWeightsFromDistributedReqInput
,
self
.
update_weights_from_distributed
,
),
(
UpdateWeightsFromTensorReqInput
,
self
.
update_weights_from_tensor
),
(
GetWeightsByNameReqInput
,
self
.
get_weights_by_name
),
(
ProfileReq
,
self
.
profile
),
(
OpenSessionReqInput
,
self
.
open_session
),
(
CloseSessionReqInput
,
self
.
close_session
),
(
ReleaseMemoryOccupationReqInput
,
lambda
_
:
self
.
release_memory_occupation
(),
),
(
ResumeMemoryOccupationReqInput
,
lambda
_
:
self
.
resume_memory_occupation
(),
),
]
)
def
watchdog_thread
(
self
):
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
self
.
watchdog_last_forward_ct
=
0
...
...
@@ -563,57 +591,9 @@ class Scheduler:
def
process_input_requests
(
self
,
recv_reqs
:
List
):
for
recv_req
in
recv_reqs
:
if
isinstance
(
recv_req
,
TokenizedGenerateReqInput
):
self
.
handle_generate_request
(
recv_req
)
elif
isinstance
(
recv_req
,
TokenizedEmbeddingReqInput
):
self
.
handle_embedding_request
(
recv_req
)
elif
isinstance
(
recv_req
,
FlushCacheReq
):
self
.
flush_cache
()
elif
isinstance
(
recv_req
,
AbortReq
):
self
.
abort_request
(
recv_req
)
elif
isinstance
(
recv_req
,
UpdateWeightFromDiskReqInput
):
success
,
message
=
self
.
update_weights_from_disk
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
UpdateWeightFromDiskReqOutput
(
success
,
message
)
)
elif
isinstance
(
recv_req
,
InitWeightsUpdateGroupReqInput
):
success
,
message
=
self
.
init_weights_update_group
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
InitWeightsUpdateGroupReqOutput
(
success
,
message
)
)
elif
isinstance
(
recv_req
,
UpdateWeightsFromDistributedReqInput
):
success
,
message
=
self
.
update_weights_from_distributed
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
UpdateWeightsFromDistributedReqOutput
(
success
,
message
)
)
elif
isinstance
(
recv_req
,
UpdateWeightsFromTensorReqInput
):
success
,
message
=
self
.
update_weights_from_tensor
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
UpdateWeightsFromTensorReqOutput
(
success
,
message
)
)
elif
isinstance
(
recv_req
,
GetWeightsByNameReqInput
):
parameter
=
self
.
get_weights_by_name
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
GetWeightsByNameReqOutput
(
parameter
))
elif
isinstance
(
recv_req
,
ReleaseMemoryOccupationReqInput
):
self
.
release_memory_occupation
()
self
.
send_to_tokenizer
.
send_pyobj
(
ReleaseMemoryOccupationReqOutput
())
elif
isinstance
(
recv_req
,
ResumeMemoryOccupationReqInput
):
self
.
resume_memory_occupation
()
self
.
send_to_tokenizer
.
send_pyobj
(
ResumeMemoryOccupationReqOutput
())
elif
isinstance
(
recv_req
,
ProfileReq
):
if
recv_req
==
ProfileReq
.
START_PROFILE
:
self
.
start_profile
()
else
:
self
.
stop_profile
()
elif
isinstance
(
recv_req
,
OpenSessionReqInput
):
session_id
,
success
=
self
.
open_session
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
OpenSessionReqOutput
(
session_id
=
session_id
,
success
=
success
)
)
elif
isinstance
(
recv_req
,
CloseSessionReqInput
):
self
.
close_session
(
recv_req
)
else
:
raise
ValueError
(
f
"Invalid request:
{
recv_req
}
"
)
output
=
self
.
_dispatcher
(
recv_req
)
if
output
is
not
None
:
self
.
send_to_tokenizer
.
send_pyobj
(
output
)
def
handle_generate_request
(
self
,
...
...
@@ -1545,6 +1525,9 @@ class Scheduler:
self
.
waiting_queue
.
extend
(
self
.
grammar_queue
[:
num_ready_reqs
])
self
.
grammar_queue
=
self
.
grammar_queue
[
num_ready_reqs
:]
def flush_cache_wrapped(self, recv_req: FlushCacheReq):
    """Dispatcher adapter: accept a FlushCacheReq and flush the cache.

    The request object carries no payload — it only triggers the flush —
    so it is deliberately ignored.
    """
    self.flush_cache()
def
flush_cache
(
self
):
"""Flush the memory pool and cache."""
if
len
(
self
.
waiting_queue
)
==
0
and
(
...
...
@@ -1597,12 +1580,12 @@ class Scheduler:
assert
flash_cache_success
,
"Cache flush failed after updating weights"
else
:
logger
.
error
(
message
)
return
success
,
message
return
UpdateWeightFromDiskReqOutput
(
success
,
message
)
def
init_weights_update_group
(
self
,
recv_req
:
InitWeightsUpdateGroupReqInput
):
"""Initialize the online model parameter update group."""
success
,
message
=
self
.
tp_worker
.
init_weights_update_group
(
recv_req
)
return
success
,
message
return
InitWeightsUpdateGroupReqOutput
(
success
,
message
)
def
update_weights_from_distributed
(
self
,
...
...
@@ -1615,7 +1598,7 @@ class Scheduler:
assert
flash_cache_success
,
"Cache flush failed after updating weights"
else
:
logger
.
error
(
message
)
return
success
,
message
return
UpdateWeightsFromDistributedReqOutput
(
success
,
message
)
def
update_weights_from_tensor
(
self
,
recv_req
:
UpdateWeightsFromTensorReqInput
):
"""Update the online model parameter from tensors."""
...
...
@@ -1626,11 +1609,11 @@ class Scheduler:
assert
flash_cache_success
,
"Cache flush failed after updating weights"
else
:
logger
.
error
(
message
)
return
success
,
message
return
UpdateWeightsFromTensorReqOutput
(
success
,
message
)
def
get_weights_by_name
(
self
,
recv_req
:
GetWeightsByNameReqInput
):
parameter
=
self
.
tp_worker
.
get_weights_by_name
(
recv_req
)
return
parameter
return
GetWeightsByNameReqOutput
(
parameter
)
def
release_memory_occupation
(
self
):
self
.
stashed_model_static_state
=
_export_static_state
(
...
...
@@ -1638,6 +1621,7 @@ class Scheduler:
)
self
.
memory_saver_adapter
.
pause
()
self
.
flush_cache
()
return
ReleaseMemoryOccupationReqOutput
()
def
resume_memory_occupation
(
self
):
self
.
memory_saver_adapter
.
resume
()
...
...
@@ -1645,6 +1629,13 @@ class Scheduler:
self
.
tp_worker
.
worker
.
model_runner
.
model
,
self
.
stashed_model_static_state
)
del
self
.
stashed_model_static_state
return
ResumeMemoryOccupationReqOutput
()
def profile(self, recv_req: ProfileReq):
    """Toggle the profiler: start on START_PROFILE, stop on anything else."""
    if recv_req != ProfileReq.START_PROFILE:
        self.stop_profile()
    else:
        self.start_profile()
def
start_profile
(
self
)
->
None
:
if
self
.
profiler
is
None
:
...
...
@@ -1660,20 +1651,20 @@ class Scheduler:
)
logger
.
info
(
"Profiler is done"
)
def
open_session
(
self
,
recv_req
:
OpenSessionReqInput
)
->
Tuple
[
Optional
[
str
],
bool
]
:
def
open_session
(
self
,
recv_req
:
OpenSessionReqInput
):
# handle error
session_id
=
recv_req
.
session_id
if
session_id
in
self
.
sessions
:
logger
.
warning
(
f
"session id
{
session_id
}
already exist, cannot open."
)
return
session_id
,
False
return
OpenSessionReqOutput
(
session_id
,
False
)
elif
session_id
is
None
:
logger
.
warning
(
f
"session id is None, cannot open."
)
return
session_id
,
False
return
OpenSessionReqOutput
(
session_id
,
False
)
else
:
self
.
sessions
[
session_id
]
=
Session
(
recv_req
.
capacity_of_str_len
,
session_id
)
return
session_id
,
True
return
OpenSessionReqOutput
(
session_id
,
True
)
def
close_session
(
self
,
recv_req
:
CloseSessionReqInput
):
# handle error
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
81d27c8e
...
...
@@ -80,7 +80,7 @@ from sglang.srt.utils import (
get_zmq_socket
,
kill_process_tree
,
)
from
sglang.utils
import
get_exception_traceback
from
sglang.utils
import
TypeBasedDispatcher
,
get_exception_traceback
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
...
...
@@ -221,6 +221,43 @@ class TokenizerManager:
},
)
self
.
_dispatcher
=
TypeBasedDispatcher
(
[
(
BatchStrOut
,
self
.
_handle_batch_output
),
(
BatchEmbeddingOut
,
self
.
_handle_batch_output
),
(
BatchTokenIDOut
,
self
.
_handle_batch_output
),
(
OpenSessionReqOutput
,
self
.
_handle_open_session_req_output
),
(
UpdateWeightFromDiskReqOutput
,
self
.
_handle_update_weights_from_disk_req_output
,
),
(
InitWeightsUpdateGroupReqOutput
,
self
.
init_weights_update_group_communicator
.
handle_recv
,
),
(
UpdateWeightsFromDistributedReqOutput
,
self
.
update_weights_from_distributed_communicator
.
handle_recv
,
),
(
UpdateWeightsFromTensorReqOutput
,
self
.
update_weights_from_tensor_communicator
.
handle_recv
,
),
(
GetWeightsByNameReqOutput
,
self
.
get_weights_by_name_communicator
.
handle_recv
,
),
(
ReleaseMemoryOccupationReqOutput
,
self
.
release_memory_occupation_communicator
.
handle_recv
,
),
(
ResumeMemoryOccupationReqOutput
,
self
.
resume_memory_occupation_communicator
.
handle_recv
,
),
]
)
async
def
generate_request
(
self
,
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
...
...
@@ -712,19 +749,12 @@ class TokenizerManager:
"""The event loop that handles requests"""
while
True
:
recv_obj
:
Union
[
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
,
UpdateWeightFromDiskReqOutput
,
UpdateWeightsFromDistributedReqOutput
,
GetWeightsByNameReqOutput
,
InitWeightsUpdateGroupReqOutput
,
ReleaseMemoryOccupationReqOutput
,
ResumeMemoryOccupationReqOutput
,
]
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
recv_obj
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
self
.
_dispatcher
(
recv_obj
)
if
isinstance
(
recv_obj
,
(
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
)):
def
_handle_batch_output
(
self
,
recv_obj
:
Union
[
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
]
):
for
i
,
rid
in
enumerate
(
recv_obj
.
rids
):
state
=
self
.
rid_to_state
.
get
(
rid
,
None
)
if
state
is
None
:
...
...
@@ -775,47 +805,8 @@ class TokenizerManager:
if
self
.
enable_metrics
and
state
.
obj
.
log_metrics
:
self
.
collect_metrics
(
state
,
recv_obj
,
i
)
if
(
self
.
dump_requests_folder
and
state
.
finished
and
state
.
obj
.
log_metrics
):
if
self
.
dump_requests_folder
and
state
.
finished
and
state
.
obj
.
log_metrics
:
self
.
dump_requests
(
state
,
out_dict
)
elif
isinstance
(
recv_obj
,
OpenSessionReqOutput
):
self
.
session_futures
[
recv_obj
.
session_id
].
set_result
(
recv_obj
.
session_id
if
recv_obj
.
success
else
None
)
elif
isinstance
(
recv_obj
,
UpdateWeightFromDiskReqOutput
):
if
self
.
server_args
.
dp_size
==
1
:
self
.
model_update_result
.
set_result
(
recv_obj
)
else
:
# self.server_args.dp_size > 1
self
.
model_update_tmp
.
append
(
recv_obj
)
# set future if the all results are recevied
if
len
(
self
.
model_update_tmp
)
==
self
.
server_args
.
dp_size
:
self
.
model_update_result
.
set_result
(
self
.
model_update_tmp
)
elif
isinstance
(
recv_obj
,
InitWeightsUpdateGroupReqOutput
):
assert
(
self
.
server_args
.
dp_size
==
1
),
"dp_size must be 1 for init parameter update group"
self
.
init_weights_update_group_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
UpdateWeightsFromDistributedReqOutput
):
assert
(
self
.
server_args
.
dp_size
==
1
),
"dp_size must be 1 for update weights from distributed"
self
.
update_weights_from_distributed_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
UpdateWeightsFromTensorReqOutput
):
assert
(
self
.
server_args
.
dp_size
==
1
),
"dp_size must be 1 for update weights from distributed"
self
.
update_weights_from_tensor_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
GetWeightsByNameReqOutput
):
self
.
get_weights_by_name_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
ReleaseMemoryOccupationReqOutput
):
self
.
release_memory_occupation_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
ResumeMemoryOccupationReqOutput
):
self
.
resume_memory_occupation_communicator
.
handle_recv
(
recv_obj
)
else
:
raise
ValueError
(
f
"Invalid object:
{
recv_obj
=
}
"
)
def
convert_logprob_style
(
self
,
...
...
@@ -943,6 +934,20 @@ class TokenizerManager:
# Schedule the task to run in the background without awaiting it
asyncio
.
create_task
(
asyncio
.
to_thread
(
background_task
))
def
_handle_open_session_req_output
(
self
,
recv_obj
):
self
.
session_futures
[
recv_obj
.
session_id
].
set_result
(
recv_obj
.
session_id
if
recv_obj
.
success
else
None
)
def
_handle_update_weights_from_disk_req_output
(
self
,
recv_obj
):
if
self
.
server_args
.
dp_size
==
1
:
self
.
model_update_result
.
set_result
(
recv_obj
)
else
:
# self.server_args.dp_size > 1
self
.
model_update_tmp
.
append
(
recv_obj
)
# set future if the all results are recevied
if
len
(
self
.
model_update_tmp
)
==
self
.
server_args
.
dp_size
:
self
.
model_update_result
.
set_result
(
self
.
model_update_tmp
)
async
def
print_exception_wrapper
(
func
):
"""
...
...
python/sglang/utils.py
View file @
81d27c8e
...
...
@@ -15,7 +15,7 @@ import urllib.request
from
concurrent.futures
import
ThreadPoolExecutor
from
io
import
BytesIO
from
json
import
dumps
from
typing
import
Optional
,
Union
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
numpy
as
np
import
requests
...
...
@@ -363,3 +363,14 @@ def terminate_process(process):
def
print_highlight
(
html_content
:
str
):
html_content
=
str
(
html_content
).
replace
(
"
\n
"
,
"<br>"
)
display
(
HTML
(
f
"<strong style='color: #00008B;'>
{
html_content
}
</strong>"
))
class TypeBasedDispatcher:
    """Route an object to a handler selected by the object's type.

    Holds an ordered list of ``(type, handler)`` pairs. Calling the
    dispatcher with an object invokes the handler of the first pair
    whose type matches via ``isinstance`` and returns its result;
    an object matching no registered type raises ValueError.
    """

    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        # Order matters: the first isinstance match wins, so more
        # specific types should precede their base classes.
        self._mapping = mapping

    def __call__(self, obj: Any):
        for entry_type, handler in self._mapping:
            if isinstance(obj, entry_type):
                return handler(obj)
        raise ValueError(f"Invalid object: {obj}")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment