Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b2ed5c8e
Unverified
Commit
b2ed5c8e
authored
Dec 27, 2024
by
fzyzcjy
Committed by
GitHub
Dec 26, 2024
Browse files
Tiny code cleanup in tokenizer_manager.py (#2586)
parent
f46f394f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
82 deletions
+74
-82
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+54
-45
python/sglang/srt/server.py
python/sglang/srt/server.py
+20
-37
No files found.
python/sglang/srt/managers/tokenizer_manager.py
View file @
b2ed5c8e
...
@@ -22,7 +22,7 @@ import signal
...
@@ -22,7 +22,7 @@ import signal
import
sys
import
sys
import
time
import
time
import
uuid
import
uuid
from
typing
import
Any
,
Awaitable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Awaitable
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
,
TypeVar
,
Union
import
fastapi
import
fastapi
import
uvloop
import
uvloop
...
@@ -173,6 +173,15 @@ class TokenizerManager:
...
@@ -173,6 +173,15 @@ class TokenizerManager:
# Others
# Others
self
.
gracefully_exit
=
False
self
.
gracefully_exit
=
False
self
.
init_weights_update_group_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
)
self
.
update_weights_from_distributed_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
)
self
.
get_weights_by_name_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
)
# Metrics
# Metrics
if
self
.
enable_metrics
:
if
self
.
enable_metrics
:
...
@@ -190,8 +199,7 @@ class TokenizerManager:
...
@@ -190,8 +199,7 @@ class TokenizerManager:
):
):
created_time
=
time
.
time
()
created_time
=
time
.
time
()
if
self
.
to_create_loop
:
self
.
auto_create_handle_loop
()
self
.
create_handle_loop
()
if
isinstance
(
obj
,
EmbeddingReqInput
)
and
self
.
is_generation
:
if
isinstance
(
obj
,
EmbeddingReqInput
)
and
self
.
is_generation
:
raise
ValueError
(
raise
ValueError
(
...
@@ -440,8 +448,7 @@ class TokenizerManager:
...
@@ -440,8 +448,7 @@ class TokenizerManager:
obj
:
UpdateWeightFromDiskReqInput
,
obj
:
UpdateWeightFromDiskReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
,
request
:
Optional
[
fastapi
.
Request
]
=
None
,
)
->
Tuple
[
bool
,
str
]:
)
->
Tuple
[
bool
,
str
]:
if
self
.
to_create_loop
:
self
.
auto_create_handle_loop
()
self
.
create_handle_loop
()
# default the load format to the server_args
# default the load format to the server_args
if
obj
.
load_format
is
None
:
if
obj
.
load_format
is
None
:
...
@@ -456,7 +463,7 @@ class TokenizerManager:
...
@@ -456,7 +463,7 @@ class TokenizerManager:
async
def
_wait_for_model_update_from_disk
(
async
def
_wait_for_model_update_from_disk
(
self
,
obj
:
UpdateWeightFromDiskReqInput
self
,
obj
:
UpdateWeightFromDiskReqInput
)
->
Tuple
[
bool
,
str
,
int
]:
)
->
Tuple
[
bool
,
str
]:
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
self
.
model_update_result
=
asyncio
.
Future
()
self
.
model_update_result
=
asyncio
.
Future
()
if
self
.
server_args
.
dp_size
==
1
:
if
self
.
server_args
.
dp_size
==
1
:
...
@@ -485,15 +492,11 @@ class TokenizerManager:
...
@@ -485,15 +492,11 @@ class TokenizerManager:
obj
:
InitWeightsUpdateGroupReqInput
,
obj
:
InitWeightsUpdateGroupReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
,
request
:
Optional
[
fastapi
.
Request
]
=
None
,
)
->
Tuple
[
bool
,
str
]:
)
->
Tuple
[
bool
,
str
]:
if
self
.
to_create_loop
:
self
.
auto_create_handle_loop
()
self
.
create_handle_loop
()
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
self
.
init_weights_update_group_result
=
asyncio
.
Future
()
assert
(
assert
(
self
.
server_args
.
dp_size
==
1
self
.
server_args
.
dp_size
==
1
),
"dp_size must be 1 for init parameter update group"
),
"dp_size must be 1 for init parameter update group"
result
=
await
self
.
init_weights_update_group_
result
result
=
(
await
self
.
init_weights_update_group_
communicator
(
obj
))[
0
]
return
result
.
success
,
result
.
message
return
result
.
success
,
result
.
message
async
def
update_weights_from_distributed
(
async
def
update_weights_from_distributed
(
...
@@ -501,44 +504,32 @@ class TokenizerManager:
...
@@ -501,44 +504,32 @@ class TokenizerManager:
obj
:
UpdateWeightsFromDistributedReqInput
,
obj
:
UpdateWeightsFromDistributedReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
,
request
:
Optional
[
fastapi
.
Request
]
=
None
,
)
->
Tuple
[
bool
,
str
]:
)
->
Tuple
[
bool
,
str
]:
if
self
.
to_create_loop
:
self
.
auto_create_handle_loop
()
self
.
create_handle_loop
()
assert
(
self
.
server_args
.
dp_size
==
1
),
"dp_size must be for update weights from distributed"
# This means that weight sync
# This means that weight sync
# cannot run while requests are in progress.
# cannot run while requests are in progress.
async
with
self
.
model_update_lock
.
writer_lock
:
async
with
self
.
model_update_lock
.
writer_lock
:
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
result
=
(
await
self
.
update_weights_from_distributed_communicator
(
obj
))[
0
]
self
.
parameter_update_result
:
Awaitable
[
UpdateWeightsFromDistributedReqOutput
]
=
asyncio
.
Future
()
assert
(
self
.
server_args
.
dp_size
==
1
),
"dp_size must be for update weights from distributed"
result
=
await
self
.
parameter_update_result
return
result
.
success
,
result
.
message
return
result
.
success
,
result
.
message
async
def
get_weights_by_name
(
async
def
get_weights_by_name
(
self
,
obj
:
GetWeightsByNameReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
self
,
obj
:
GetWeightsByNameReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
):
):
if
self
.
to_create_loop
:
self
.
auto_create_handle_loop
()
self
.
create_handle_loop
()
results
=
await
self
.
get_weights_by_name_communicator
(
obj
)
all_parameters
=
[
r
.
parameter
for
r
in
results
]
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
self
.
get_weights_by_name_result
=
asyncio
.
Future
()
if
self
.
server_args
.
dp_size
==
1
:
if
self
.
server_args
.
dp_size
==
1
:
result
=
await
self
.
get_weights_by_name_result
return
all_parameters
[
0
]
return
result
.
parameter
else
:
else
:
self
.
get_weights_by_name_tmp
=
[]
result
=
await
self
.
get_weights_by_name_result
all_parameters
=
[
r
.
parameter
for
r
in
result
]
return
all_parameters
return
all_parameters
async
def
open_session
(
async
def
open_session
(
self
,
obj
:
OpenSessionReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
self
,
obj
:
OpenSessionReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
):
):
if
self
.
to_create_loop
:
self
.
auto_create_handle_loop
()
self
.
create_handle_loop
()
session_id
=
uuid
.
uuid4
().
hex
session_id
=
uuid
.
uuid4
().
hex
obj
.
session_id
=
session_id
obj
.
session_id
=
session_id
...
@@ -568,7 +559,7 @@ class TokenizerManager:
...
@@ -568,7 +559,7 @@ class TokenizerManager:
background_tasks
.
add_task
(
abort_request
)
background_tasks
.
add_task
(
abort_request
)
return
background_tasks
return
background_tasks
def
create_handle_loop
(
self
):
def
auto_
create_handle_loop
(
self
):
if
not
self
.
to_create_loop
:
if
not
self
.
to_create_loop
:
return
return
...
@@ -711,21 +702,14 @@ class TokenizerManager:
...
@@ -711,21 +702,14 @@ class TokenizerManager:
assert
(
assert
(
self
.
server_args
.
dp_size
==
1
self
.
server_args
.
dp_size
==
1
),
"dp_size must be 1 for init parameter update group"
),
"dp_size must be 1 for init parameter update group"
self
.
init_weights_update_group_
result
.
set_result
(
recv_obj
)
self
.
init_weights_update_group_
communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
UpdateWeightsFromDistributedReqOutput
):
elif
isinstance
(
recv_obj
,
UpdateWeightsFromDistributedReqOutput
):
assert
(
assert
(
self
.
server_args
.
dp_size
==
1
self
.
server_args
.
dp_size
==
1
),
"dp_size must be 1 for update weights from distributed"
),
"dp_size must be 1 for update weights from distributed"
self
.
parameter_update_result
.
set_result
(
recv_obj
)
self
.
update_weights_from_distributed_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
GetWeightsByNameReqOutput
):
elif
isinstance
(
recv_obj
,
GetWeightsByNameReqOutput
):
if
self
.
server_args
.
dp_size
==
1
:
self
.
get_weights_by_name_communicator
.
handle_recv
(
recv_obj
)
self
.
get_weights_by_name_result
.
set_result
(
recv_obj
)
else
:
self
.
get_weights_by_name_tmp
.
append
(
recv_obj
)
if
len
(
self
.
get_weights_by_name_tmp
)
==
self
.
server_args
.
dp_size
:
self
.
get_weights_by_name_result
.
set_result
(
self
.
get_weights_by_name_tmp
)
else
:
else
:
raise
ValueError
(
f
"Invalid object:
{
recv_obj
=
}
"
)
raise
ValueError
(
f
"Invalid object:
{
recv_obj
=
}
"
)
...
@@ -809,3 +793,28 @@ class SignalHandler:
...
@@ -809,3 +793,28 @@ class SignalHandler:
f
"SIGTERM received.
{
signum
=
}
{
frame
=
}
. Draining requests and shutting down..."
f
"SIGTERM received.
{
signum
=
}
{
frame
=
}
. Draining requests and shutting down..."
)
)
self
.
tokenizer_manager
.
gracefully_exit
=
True
self
.
tokenizer_manager
.
gracefully_exit
=
True
T
=
TypeVar
(
"T"
)
class
_Communicator
(
Generic
[
T
]):
def
__init__
(
self
,
sender
,
fan_out
:
int
):
self
.
_sender
=
sender
self
.
_fan_out
=
fan_out
self
.
_result_future
:
Optional
[
asyncio
.
Future
]
=
None
self
.
_result_values
:
Optional
[
List
[
T
]]
=
None
async
def
__call__
(
self
,
obj
):
self
.
_sender
.
send_pyobj
(
obj
)
self
.
_result_future
=
asyncio
.
Future
()
self
.
_result_values
=
[]
await
self
.
_result_future
result_values
=
self
.
_result_values
self
.
_result_future
=
self
.
_result_values
=
None
return
result_values
def
handle_recv
(
self
,
recv_obj
:
T
):
self
.
_result_values
.
append
(
recv_obj
)
if
len
(
self
.
_result_values
)
==
self
.
_fan_out
:
self
.
_result_future
.
set_result
(
None
)
python/sglang/srt/server.py
View file @
b2ed5c8e
...
@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
...
@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
try
:
try
:
ret
=
await
tokenizer_manager
.
get_weights_by_name
(
obj
,
request
)
ret
=
await
tokenizer_manager
.
get_weights_by_name
(
obj
,
request
)
if
ret
is
None
:
if
ret
is
None
:
return
ORJSONResponse
(
return
_create_error_response
(
"Get parameter by name failed"
)
{
"error"
:
{
"message"
:
"Get parameter by name failed"
}},
status_code
=
HTTPStatus
.
BAD_REQUEST
,
)
else
:
else
:
return
ORJSONResponse
(
ret
,
status_code
=
200
)
return
ORJSONResponse
(
ret
,
status_code
=
200
)
except
Exception
as
e
:
except
Exception
as
e
:
return
ORJSONResponse
(
return
_create_error_response
(
e
)
{
"error"
:
{
"message"
:
str
(
e
)}},
status_code
=
HTTPStatus
.
BAD_REQUEST
)
@
app
.
api_route
(
"/open_session"
,
methods
=
[
"GET"
,
"POST"
])
@
app
.
api_route
(
"/open_session"
,
methods
=
[
"GET"
,
"POST"
])
...
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
...
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
session_id
=
await
tokenizer_manager
.
open_session
(
obj
,
request
)
session_id
=
await
tokenizer_manager
.
open_session
(
obj
,
request
)
return
session_id
return
session_id
except
Exception
as
e
:
except
Exception
as
e
:
return
ORJSONResponse
(
return
_create_error_response
(
e
)
{
"error"
:
{
"message"
:
str
(
e
)}},
status_code
=
HTTPStatus
.
BAD_REQUEST
)
@
app
.
api_route
(
"/close_session"
,
methods
=
[
"GET"
,
"POST"
])
@
app
.
api_route
(
"/close_session"
,
methods
=
[
"GET"
,
"POST"
])
...
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
...
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
await
tokenizer_manager
.
close_session
(
obj
,
request
)
await
tokenizer_manager
.
close_session
(
obj
,
request
)
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
except
Exception
as
e
:
except
Exception
as
e
:
return
ORJSONResponse
(
return
_create_error_response
(
e
)
{
"error"
:
{
"message"
:
str
(
e
)}},
status_code
=
HTTPStatus
.
BAD_REQUEST
)
# fastapi implicitly converts json in the request to obj (dataclass)
# fastapi implicitly converts json in the request to obj (dataclass)
...
@@ -312,9 +303,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
...
@@ -312,9 +303,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
return
ret
return
ret
except
ValueError
as
e
:
except
ValueError
as
e
:
logger
.
error
(
f
"Error:
{
e
}
"
)
logger
.
error
(
f
"Error:
{
e
}
"
)
return
ORJSONResponse
(
return
_create_error_response
(
e
)
{
"error"
:
{
"message"
:
str
(
e
)}},
status_code
=
HTTPStatus
.
BAD_REQUEST
)
@
app
.
api_route
(
"/encode"
,
methods
=
[
"POST"
,
"PUT"
])
@
app
.
api_route
(
"/encode"
,
methods
=
[
"POST"
,
"PUT"
])
...
@@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
...
@@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
ret
=
await
tokenizer_manager
.
generate_request
(
obj
,
request
).
__anext__
()
ret
=
await
tokenizer_manager
.
generate_request
(
obj
,
request
).
__anext__
()
return
ret
return
ret
except
ValueError
as
e
:
except
ValueError
as
e
:
return
ORJSONResponse
(
return
_create_error_response
(
e
)
{
"error"
:
{
"message"
:
str
(
e
)}},
status_code
=
HTTPStatus
.
BAD_REQUEST
)
@
app
.
api_route
(
"/classify"
,
methods
=
[
"POST"
,
"PUT"
])
@
app
.
api_route
(
"/classify"
,
methods
=
[
"POST"
,
"PUT"
])
...
@@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
...
@@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
ret
=
await
tokenizer_manager
.
generate_request
(
obj
,
request
).
__anext__
()
ret
=
await
tokenizer_manager
.
generate_request
(
obj
,
request
).
__anext__
()
return
ret
return
ret
except
ValueError
as
e
:
except
ValueError
as
e
:
return
ORJSONResponse
(
return
_create_error_response
(
e
)
{
"error"
:
{
"message"
:
str
(
e
)}},
status_code
=
HTTPStatus
.
BAD_REQUEST
)
##### OpenAI-compatible API endpoints #####
##### OpenAI-compatible API endpoints #####
...
@@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str):
...
@@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str):
return
await
v1_retrieve_file_content
(
file_id
)
return
await
v1_retrieve_file_content
(
file_id
)
def
_create_error_response
(
e
):
return
ORJSONResponse
(
{
"error"
:
{
"message"
:
str
(
e
)}},
status_code
=
HTTPStatus
.
BAD_REQUEST
)
def
launch_engine
(
def
launch_engine
(
server_args
:
ServerArgs
,
server_args
:
ServerArgs
,
):
):
...
@@ -849,12 +840,10 @@ class Engine:
...
@@ -849,12 +840,10 @@ class Engine:
group_name
=
group_name
,
group_name
=
group_name
,
backend
=
backend
,
backend
=
backend
,
)
)
async
def
_init_group
():
return
await
tokenizer_manager
.
init_weights_update_group
(
obj
,
None
)
loop
=
asyncio
.
get_event_loop
()
loop
=
asyncio
.
get_event_loop
()
return
loop
.
run_until_complete
(
_init_group
())
return
loop
.
run_until_complete
(
tokenizer_manager
.
init_weights_update_group
(
obj
,
None
)
)
def
update_weights_from_distributed
(
self
,
name
,
dtype
,
shape
):
def
update_weights_from_distributed
(
self
,
name
,
dtype
,
shape
):
"""Update weights from distributed source."""
"""Update weights from distributed source."""
...
@@ -863,22 +852,16 @@ class Engine:
...
@@ -863,22 +852,16 @@ class Engine:
dtype
=
dtype
,
dtype
=
dtype
,
shape
=
shape
,
shape
=
shape
,
)
)
async
def
_update_weights
():
return
await
tokenizer_manager
.
update_weights_from_distributed
(
obj
,
None
)
loop
=
asyncio
.
get_event_loop
()
loop
=
asyncio
.
get_event_loop
()
return
loop
.
run_until_complete
(
_update_weights
())
return
loop
.
run_until_complete
(
tokenizer_manager
.
update_weights_from_distributed
(
obj
,
None
)
)
def
get_weights_by_name
(
self
,
name
,
truncate_size
=
100
):
def
get_weights_by_name
(
self
,
name
,
truncate_size
=
100
):
"""Get weights by parameter name."""
"""Get weights by parameter name."""
obj
=
GetWeightsByNameReqInput
(
name
=
name
,
truncate_size
=
truncate_size
)
obj
=
GetWeightsByNameReqInput
(
name
=
name
,
truncate_size
=
truncate_size
)
async
def
_get_weights
():
return
await
tokenizer_manager
.
get_weights_by_name
(
obj
,
None
)
loop
=
asyncio
.
get_event_loop
()
loop
=
asyncio
.
get_event_loop
()
return
loop
.
run_until_complete
(
_get_weights
(
))
return
loop
.
run_until_complete
(
tokenizer_manager
.
get_weights_by_name
(
obj
,
None
))
class
Runtime
:
class
Runtime
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment