Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
81d27c8e
"vscode:/vscode.git/clone" did not exist on "b1fbef544c993288447181a6c8d8c68d89387ebe"
Unverified
Commit
81d27c8e
authored
Jan 19, 2025
by
fzyzcjy
Committed by
GitHub
Jan 18, 2025
Browse files
Refactor to add TypeBasedDispatcher to simplify dispatching (#2958)
parent
4d4cdb3f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
171 additions
and
164 deletions
+171
-164
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+52
-61
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+107
-102
python/sglang/utils.py
python/sglang/utils.py
+12
-1
No files found.
python/sglang/srt/managers/scheduler.py
View file @
81d27c8e
...
@@ -97,7 +97,7 @@ from sglang.srt.utils import (
...
@@ -97,7 +97,7 @@ from sglang.srt.utils import (
set_random_seed
,
set_random_seed
,
suppress_other_loggers
,
suppress_other_loggers
,
)
)
from
sglang.utils
import
get_exception_traceback
from
sglang.utils
import
TypeBasedDispatcher
,
get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -422,6 +422,34 @@ class Scheduler:
...
@@ -422,6 +422,34 @@ class Scheduler:
},
},
)
)
self
.
_dispatcher
=
TypeBasedDispatcher
(
[
(
TokenizedGenerateReqInput
,
self
.
handle_generate_request
),
(
TokenizedEmbeddingReqInput
,
self
.
handle_embedding_request
),
(
FlushCacheReq
,
self
.
flush_cache_wrapped
),
(
AbortReq
,
self
.
abort_request
),
(
UpdateWeightFromDiskReqInput
,
self
.
update_weights_from_disk
),
(
InitWeightsUpdateGroupReqInput
,
self
.
init_weights_update_group
),
(
UpdateWeightsFromDistributedReqInput
,
self
.
update_weights_from_distributed
,
),
(
UpdateWeightsFromTensorReqInput
,
self
.
update_weights_from_tensor
),
(
GetWeightsByNameReqInput
,
self
.
get_weights_by_name
),
(
ProfileReq
,
self
.
profile
),
(
OpenSessionReqInput
,
self
.
open_session
),
(
CloseSessionReqInput
,
self
.
close_session
),
(
ReleaseMemoryOccupationReqInput
,
lambda
_
:
self
.
release_memory_occupation
(),
),
(
ResumeMemoryOccupationReqInput
,
lambda
_
:
self
.
resume_memory_occupation
(),
),
]
)
def
watchdog_thread
(
self
):
def
watchdog_thread
(
self
):
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
self
.
watchdog_last_forward_ct
=
0
self
.
watchdog_last_forward_ct
=
0
...
@@ -563,57 +591,9 @@ class Scheduler:
...
@@ -563,57 +591,9 @@ class Scheduler:
def
process_input_requests
(
self
,
recv_reqs
:
List
):
def
process_input_requests
(
self
,
recv_reqs
:
List
):
for
recv_req
in
recv_reqs
:
for
recv_req
in
recv_reqs
:
if
isinstance
(
recv_req
,
TokenizedGenerateReqInput
):
output
=
self
.
_dispatcher
(
recv_req
)
self
.
handle_generate_request
(
recv_req
)
if
output
is
not
None
:
elif
isinstance
(
recv_req
,
TokenizedEmbeddingReqInput
):
self
.
send_to_tokenizer
.
send_pyobj
(
output
)
self
.
handle_embedding_request
(
recv_req
)
elif
isinstance
(
recv_req
,
FlushCacheReq
):
self
.
flush_cache
()
elif
isinstance
(
recv_req
,
AbortReq
):
self
.
abort_request
(
recv_req
)
elif
isinstance
(
recv_req
,
UpdateWeightFromDiskReqInput
):
success
,
message
=
self
.
update_weights_from_disk
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
UpdateWeightFromDiskReqOutput
(
success
,
message
)
)
elif
isinstance
(
recv_req
,
InitWeightsUpdateGroupReqInput
):
success
,
message
=
self
.
init_weights_update_group
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
InitWeightsUpdateGroupReqOutput
(
success
,
message
)
)
elif
isinstance
(
recv_req
,
UpdateWeightsFromDistributedReqInput
):
success
,
message
=
self
.
update_weights_from_distributed
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
UpdateWeightsFromDistributedReqOutput
(
success
,
message
)
)
elif
isinstance
(
recv_req
,
UpdateWeightsFromTensorReqInput
):
success
,
message
=
self
.
update_weights_from_tensor
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
UpdateWeightsFromTensorReqOutput
(
success
,
message
)
)
elif
isinstance
(
recv_req
,
GetWeightsByNameReqInput
):
parameter
=
self
.
get_weights_by_name
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
GetWeightsByNameReqOutput
(
parameter
))
elif
isinstance
(
recv_req
,
ReleaseMemoryOccupationReqInput
):
self
.
release_memory_occupation
()
self
.
send_to_tokenizer
.
send_pyobj
(
ReleaseMemoryOccupationReqOutput
())
elif
isinstance
(
recv_req
,
ResumeMemoryOccupationReqInput
):
self
.
resume_memory_occupation
()
self
.
send_to_tokenizer
.
send_pyobj
(
ResumeMemoryOccupationReqOutput
())
elif
isinstance
(
recv_req
,
ProfileReq
):
if
recv_req
==
ProfileReq
.
START_PROFILE
:
self
.
start_profile
()
else
:
self
.
stop_profile
()
elif
isinstance
(
recv_req
,
OpenSessionReqInput
):
session_id
,
success
=
self
.
open_session
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
OpenSessionReqOutput
(
session_id
=
session_id
,
success
=
success
)
)
elif
isinstance
(
recv_req
,
CloseSessionReqInput
):
self
.
close_session
(
recv_req
)
else
:
raise
ValueError
(
f
"Invalid request:
{
recv_req
}
"
)
def
handle_generate_request
(
def
handle_generate_request
(
self
,
self
,
...
@@ -1545,6 +1525,9 @@ class Scheduler:
...
@@ -1545,6 +1525,9 @@ class Scheduler:
self
.
waiting_queue
.
extend
(
self
.
grammar_queue
[:
num_ready_reqs
])
self
.
waiting_queue
.
extend
(
self
.
grammar_queue
[:
num_ready_reqs
])
self
.
grammar_queue
=
self
.
grammar_queue
[
num_ready_reqs
:]
self
.
grammar_queue
=
self
.
grammar_queue
[
num_ready_reqs
:]
def
flush_cache_wrapped
(
self
,
recv_req
:
FlushCacheReq
):
self
.
flush_cache
()
def
flush_cache
(
self
):
def
flush_cache
(
self
):
"""Flush the memory pool and cache."""
"""Flush the memory pool and cache."""
if
len
(
self
.
waiting_queue
)
==
0
and
(
if
len
(
self
.
waiting_queue
)
==
0
and
(
...
@@ -1597,12 +1580,12 @@ class Scheduler:
...
@@ -1597,12 +1580,12 @@ class Scheduler:
assert
flash_cache_success
,
"Cache flush failed after updating weights"
assert
flash_cache_success
,
"Cache flush failed after updating weights"
else
:
else
:
logger
.
error
(
message
)
logger
.
error
(
message
)
return
success
,
message
return
UpdateWeightFromDiskReqOutput
(
success
,
message
)
def
init_weights_update_group
(
self
,
recv_req
:
InitWeightsUpdateGroupReqInput
):
def
init_weights_update_group
(
self
,
recv_req
:
InitWeightsUpdateGroupReqInput
):
"""Initialize the online model parameter update group."""
"""Initialize the online model parameter update group."""
success
,
message
=
self
.
tp_worker
.
init_weights_update_group
(
recv_req
)
success
,
message
=
self
.
tp_worker
.
init_weights_update_group
(
recv_req
)
return
success
,
message
return
InitWeightsUpdateGroupReqOutput
(
success
,
message
)
def
update_weights_from_distributed
(
def
update_weights_from_distributed
(
self
,
self
,
...
@@ -1615,7 +1598,7 @@ class Scheduler:
...
@@ -1615,7 +1598,7 @@ class Scheduler:
assert
flash_cache_success
,
"Cache flush failed after updating weights"
assert
flash_cache_success
,
"Cache flush failed after updating weights"
else
:
else
:
logger
.
error
(
message
)
logger
.
error
(
message
)
return
success
,
message
return
UpdateWeightsFromDistributedReqOutput
(
success
,
message
)
def
update_weights_from_tensor
(
self
,
recv_req
:
UpdateWeightsFromTensorReqInput
):
def
update_weights_from_tensor
(
self
,
recv_req
:
UpdateWeightsFromTensorReqInput
):
"""Update the online model parameter from tensors."""
"""Update the online model parameter from tensors."""
...
@@ -1626,11 +1609,11 @@ class Scheduler:
...
@@ -1626,11 +1609,11 @@ class Scheduler:
assert
flash_cache_success
,
"Cache flush failed after updating weights"
assert
flash_cache_success
,
"Cache flush failed after updating weights"
else
:
else
:
logger
.
error
(
message
)
logger
.
error
(
message
)
return
success
,
message
return
UpdateWeightsFromTensorReqOutput
(
success
,
message
)
def
get_weights_by_name
(
self
,
recv_req
:
GetWeightsByNameReqInput
):
def
get_weights_by_name
(
self
,
recv_req
:
GetWeightsByNameReqInput
):
parameter
=
self
.
tp_worker
.
get_weights_by_name
(
recv_req
)
parameter
=
self
.
tp_worker
.
get_weights_by_name
(
recv_req
)
return
parameter
return
GetWeightsByNameReqOutput
(
parameter
)
def
release_memory_occupation
(
self
):
def
release_memory_occupation
(
self
):
self
.
stashed_model_static_state
=
_export_static_state
(
self
.
stashed_model_static_state
=
_export_static_state
(
...
@@ -1638,6 +1621,7 @@ class Scheduler:
...
@@ -1638,6 +1621,7 @@ class Scheduler:
)
)
self
.
memory_saver_adapter
.
pause
()
self
.
memory_saver_adapter
.
pause
()
self
.
flush_cache
()
self
.
flush_cache
()
return
ReleaseMemoryOccupationReqOutput
()
def
resume_memory_occupation
(
self
):
def
resume_memory_occupation
(
self
):
self
.
memory_saver_adapter
.
resume
()
self
.
memory_saver_adapter
.
resume
()
...
@@ -1645,6 +1629,13 @@ class Scheduler:
...
@@ -1645,6 +1629,13 @@ class Scheduler:
self
.
tp_worker
.
worker
.
model_runner
.
model
,
self
.
stashed_model_static_state
self
.
tp_worker
.
worker
.
model_runner
.
model
,
self
.
stashed_model_static_state
)
)
del
self
.
stashed_model_static_state
del
self
.
stashed_model_static_state
return
ResumeMemoryOccupationReqOutput
()
def
profile
(
self
,
recv_req
:
ProfileReq
):
if
recv_req
==
ProfileReq
.
START_PROFILE
:
self
.
start_profile
()
else
:
self
.
stop_profile
()
def
start_profile
(
self
)
->
None
:
def
start_profile
(
self
)
->
None
:
if
self
.
profiler
is
None
:
if
self
.
profiler
is
None
:
...
@@ -1660,20 +1651,20 @@ class Scheduler:
...
@@ -1660,20 +1651,20 @@ class Scheduler:
)
)
logger
.
info
(
"Profiler is done"
)
logger
.
info
(
"Profiler is done"
)
def
open_session
(
self
,
recv_req
:
OpenSessionReqInput
)
->
Tuple
[
Optional
[
str
],
bool
]
:
def
open_session
(
self
,
recv_req
:
OpenSessionReqInput
):
# handle error
# handle error
session_id
=
recv_req
.
session_id
session_id
=
recv_req
.
session_id
if
session_id
in
self
.
sessions
:
if
session_id
in
self
.
sessions
:
logger
.
warning
(
f
"session id
{
session_id
}
already exist, cannot open."
)
logger
.
warning
(
f
"session id
{
session_id
}
already exist, cannot open."
)
return
session_id
,
False
return
OpenSessionReqOutput
(
session_id
,
False
)
elif
session_id
is
None
:
elif
session_id
is
None
:
logger
.
warning
(
f
"session id is None, cannot open."
)
logger
.
warning
(
f
"session id is None, cannot open."
)
return
session_id
,
False
return
OpenSessionReqOutput
(
session_id
,
False
)
else
:
else
:
self
.
sessions
[
session_id
]
=
Session
(
self
.
sessions
[
session_id
]
=
Session
(
recv_req
.
capacity_of_str_len
,
session_id
recv_req
.
capacity_of_str_len
,
session_id
)
)
return
session_id
,
True
return
OpenSessionReqOutput
(
session_id
,
True
)
def
close_session
(
self
,
recv_req
:
CloseSessionReqInput
):
def
close_session
(
self
,
recv_req
:
CloseSessionReqInput
):
# handle error
# handle error
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
81d27c8e
...
@@ -80,7 +80,7 @@ from sglang.srt.utils import (
...
@@ -80,7 +80,7 @@ from sglang.srt.utils import (
get_zmq_socket
,
get_zmq_socket
,
kill_process_tree
,
kill_process_tree
,
)
)
from
sglang.utils
import
get_exception_traceback
from
sglang.utils
import
TypeBasedDispatcher
,
get_exception_traceback
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
...
@@ -221,6 +221,43 @@ class TokenizerManager:
...
@@ -221,6 +221,43 @@ class TokenizerManager:
},
},
)
)
self
.
_dispatcher
=
TypeBasedDispatcher
(
[
(
BatchStrOut
,
self
.
_handle_batch_output
),
(
BatchEmbeddingOut
,
self
.
_handle_batch_output
),
(
BatchTokenIDOut
,
self
.
_handle_batch_output
),
(
OpenSessionReqOutput
,
self
.
_handle_open_session_req_output
),
(
UpdateWeightFromDiskReqOutput
,
self
.
_handle_update_weights_from_disk_req_output
,
),
(
InitWeightsUpdateGroupReqOutput
,
self
.
init_weights_update_group_communicator
.
handle_recv
,
),
(
UpdateWeightsFromDistributedReqOutput
,
self
.
update_weights_from_distributed_communicator
.
handle_recv
,
),
(
UpdateWeightsFromTensorReqOutput
,
self
.
update_weights_from_tensor_communicator
.
handle_recv
,
),
(
GetWeightsByNameReqOutput
,
self
.
get_weights_by_name_communicator
.
handle_recv
,
),
(
ReleaseMemoryOccupationReqOutput
,
self
.
release_memory_occupation_communicator
.
handle_recv
,
),
(
ResumeMemoryOccupationReqOutput
,
self
.
resume_memory_occupation_communicator
.
handle_recv
,
),
]
)
async
def
generate_request
(
async
def
generate_request
(
self
,
self
,
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
...
@@ -712,110 +749,64 @@ class TokenizerManager:
...
@@ -712,110 +749,64 @@ class TokenizerManager:
"""The event loop that handles requests"""
"""The event loop that handles requests"""
while
True
:
while
True
:
recv_obj
:
Union
[
recv_obj
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
BatchStrOut
,
self
.
_dispatcher
(
recv_obj
)
BatchEmbeddingOut
,
BatchTokenIDOut
,
UpdateWeightFromDiskReqOutput
,
UpdateWeightsFromDistributedReqOutput
,
GetWeightsByNameReqOutput
,
InitWeightsUpdateGroupReqOutput
,
ReleaseMemoryOccupationReqOutput
,
ResumeMemoryOccupationReqOutput
,
]
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
if
isinstance
(
recv_obj
,
(
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
)):
for
i
,
rid
in
enumerate
(
recv_obj
.
rids
):
state
=
self
.
rid_to_state
.
get
(
rid
,
None
)
if
state
is
None
:
continue
meta_info
=
{
"id"
:
rid
,
"finish_reason"
:
recv_obj
.
finished_reasons
[
i
],
"prompt_tokens"
:
recv_obj
.
prompt_tokens
[
i
],
}
if
getattr
(
state
.
obj
,
"return_logprob"
,
False
):
def
_handle_batch_output
(
self
.
convert_logprob_style
(
self
,
recv_obj
:
Union
[
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
]
meta_info
,
):
state
.
obj
.
top_logprobs_num
,
for
i
,
rid
in
enumerate
(
recv_obj
.
rids
):
state
.
obj
.
return_text_in_logprobs
,
state
=
self
.
rid_to_state
.
get
(
rid
,
None
)
recv_obj
,
if
state
is
None
:
i
,
continue
)
meta_info
=
{
if
not
isinstance
(
recv_obj
,
BatchEmbeddingOut
):
"id"
:
rid
,
meta_info
.
update
(
"finish_reason"
:
recv_obj
.
finished_reasons
[
i
],
{
"prompt_tokens"
:
recv_obj
.
prompt_tokens
[
i
],
"completion_tokens"
:
recv_obj
.
completion_tokens
[
i
],
}
"cached_tokens"
:
recv_obj
.
cached_tokens
[
i
],
}
if
getattr
(
state
.
obj
,
"return_logprob"
,
False
):
)
self
.
convert_logprob_style
(
meta_info
,
if
isinstance
(
recv_obj
,
BatchStrOut
):
state
.
obj
.
top_logprobs_num
,
out_dict
=
{
state
.
obj
.
return_text_in_logprobs
,
"text"
:
recv_obj
.
output_strs
[
i
],
recv_obj
,
"meta_info"
:
meta_info
,
i
,
}
elif
isinstance
(
recv_obj
,
BatchTokenIDOut
):
out_dict
=
{
"token_ids"
:
recv_obj
.
output_ids
[
i
],
"meta_info"
:
meta_info
,
}
else
:
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
out_dict
=
{
"embedding"
:
recv_obj
.
embeddings
[
i
],
"meta_info"
:
meta_info
,
}
state
.
out_list
.
append
(
out_dict
)
state
.
finished
=
recv_obj
.
finished_reasons
[
i
]
is
not
None
state
.
event
.
set
()
if
self
.
enable_metrics
and
state
.
obj
.
log_metrics
:
self
.
collect_metrics
(
state
,
recv_obj
,
i
)
if
(
self
.
dump_requests_folder
and
state
.
finished
and
state
.
obj
.
log_metrics
):
self
.
dump_requests
(
state
,
out_dict
)
elif
isinstance
(
recv_obj
,
OpenSessionReqOutput
):
self
.
session_futures
[
recv_obj
.
session_id
].
set_result
(
recv_obj
.
session_id
if
recv_obj
.
success
else
None
)
)
elif
isinstance
(
recv_obj
,
UpdateWeightFromDiskReqOutput
):
if
self
.
server_args
.
dp_size
==
1
:
if
not
isinstance
(
recv_obj
,
BatchEmbeddingOut
):
self
.
model_update_result
.
set_result
(
recv_obj
)
meta_info
.
update
(
else
:
# self.server_args.dp_size > 1
{
self
.
model_update_tmp
.
append
(
recv_obj
)
"completion_tokens"
:
recv_obj
.
completion_tokens
[
i
],
# set future if the all results are recevied
"cached_tokens"
:
recv_obj
.
cached_tokens
[
i
],
if
len
(
self
.
model_update_tmp
)
==
self
.
server_args
.
dp_size
:
}
self
.
model_update_result
.
set_result
(
self
.
model_update_tmp
)
)
elif
isinstance
(
recv_obj
,
InitWeightsUpdateGroupReqOutput
):
assert
(
if
isinstance
(
recv_obj
,
BatchStrOut
):
self
.
server_args
.
dp_size
==
1
out_dict
=
{
),
"dp_size must be 1 for init parameter update group"
"text"
:
recv_obj
.
output_strs
[
i
],
self
.
init_weights_update_group_communicator
.
handle_recv
(
recv_obj
)
"meta_info"
:
meta_info
,
elif
isinstance
(
recv_obj
,
UpdateWeightsFromDistributedReqOutput
):
}
assert
(
elif
isinstance
(
recv_obj
,
BatchTokenIDOut
):
self
.
server_args
.
dp_size
==
1
out_dict
=
{
),
"dp_size must be 1 for update weights from distributed"
"token_ids"
:
recv_obj
.
output_ids
[
i
],
self
.
update_weights_from_distributed_communicator
.
handle_recv
(
recv_obj
)
"meta_info"
:
meta_info
,
elif
isinstance
(
recv_obj
,
UpdateWeightsFromTensorReqOutput
):
}
assert
(
self
.
server_args
.
dp_size
==
1
),
"dp_size must be 1 for update weights from distributed"
self
.
update_weights_from_tensor_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
GetWeightsByNameReqOutput
):
self
.
get_weights_by_name_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
ReleaseMemoryOccupationReqOutput
):
self
.
release_memory_occupation_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
ResumeMemoryOccupationReqOutput
):
self
.
resume_memory_occupation_communicator
.
handle_recv
(
recv_obj
)
else
:
else
:
raise
ValueError
(
f
"Invalid object:
{
recv_obj
=
}
"
)
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
out_dict
=
{
"embedding"
:
recv_obj
.
embeddings
[
i
],
"meta_info"
:
meta_info
,
}
state
.
out_list
.
append
(
out_dict
)
state
.
finished
=
recv_obj
.
finished_reasons
[
i
]
is
not
None
state
.
event
.
set
()
if
self
.
enable_metrics
and
state
.
obj
.
log_metrics
:
self
.
collect_metrics
(
state
,
recv_obj
,
i
)
if
self
.
dump_requests_folder
and
state
.
finished
and
state
.
obj
.
log_metrics
:
self
.
dump_requests
(
state
,
out_dict
)
def
convert_logprob_style
(
def
convert_logprob_style
(
self
,
self
,
...
@@ -943,6 +934,20 @@ class TokenizerManager:
...
@@ -943,6 +934,20 @@ class TokenizerManager:
# Schedule the task to run in the background without awaiting it
# Schedule the task to run in the background without awaiting it
asyncio
.
create_task
(
asyncio
.
to_thread
(
background_task
))
asyncio
.
create_task
(
asyncio
.
to_thread
(
background_task
))
def
_handle_open_session_req_output
(
self
,
recv_obj
):
self
.
session_futures
[
recv_obj
.
session_id
].
set_result
(
recv_obj
.
session_id
if
recv_obj
.
success
else
None
)
def
_handle_update_weights_from_disk_req_output
(
self
,
recv_obj
):
if
self
.
server_args
.
dp_size
==
1
:
self
.
model_update_result
.
set_result
(
recv_obj
)
else
:
# self.server_args.dp_size > 1
self
.
model_update_tmp
.
append
(
recv_obj
)
# set future if the all results are recevied
if
len
(
self
.
model_update_tmp
)
==
self
.
server_args
.
dp_size
:
self
.
model_update_result
.
set_result
(
self
.
model_update_tmp
)
async
def
print_exception_wrapper
(
func
):
async
def
print_exception_wrapper
(
func
):
"""
"""
...
...
python/sglang/utils.py
View file @
81d27c8e
...
@@ -15,7 +15,7 @@ import urllib.request
...
@@ -15,7 +15,7 @@ import urllib.request
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures
import
ThreadPoolExecutor
from
io
import
BytesIO
from
io
import
BytesIO
from
json
import
dumps
from
json
import
dumps
from
typing
import
Optional
,
Union
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
numpy
as
np
import
numpy
as
np
import
requests
import
requests
...
@@ -363,3 +363,14 @@ def terminate_process(process):
...
@@ -363,3 +363,14 @@ def terminate_process(process):
def
print_highlight
(
html_content
:
str
):
def
print_highlight
(
html_content
:
str
):
html_content
=
str
(
html_content
).
replace
(
"
\n
"
,
"<br>"
)
html_content
=
str
(
html_content
).
replace
(
"
\n
"
,
"<br>"
)
display
(
HTML
(
f
"<strong style='color: #00008B;'>
{
html_content
}
</strong>"
))
display
(
HTML
(
f
"<strong style='color: #00008B;'>
{
html_content
}
</strong>"
))
class
TypeBasedDispatcher
:
def
__init__
(
self
,
mapping
:
List
[
Tuple
[
Type
,
Callable
]]):
self
.
_mapping
=
mapping
def
__call__
(
self
,
obj
:
Any
):
for
ty
,
fn
in
self
.
_mapping
:
if
isinstance
(
obj
,
ty
):
return
fn
(
obj
)
raise
ValueError
(
f
"Invalid object:
{
obj
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment