change / sglang / Commits / b2ed5c8e
"docker/diffusers-pytorch-minimum-cuda/Dockerfile" did not exist on "9d1341d69bed9b0174892008eca06af610d5f3d8"
Unverified commit b2ed5c8e, authored Dec 27, 2024 by fzyzcjy; committed by GitHub on Dec 26, 2024.
Tiny code cleanup in tokenizer_manager.py (#2586)
Parent: f46f394f
Showing 2 changed files with 74 additions and 82 deletions.
python/sglang/srt/managers/tokenizer_manager.py (+54, -45)
python/sglang/srt/server.py (+20, -37)
python/sglang/srt/managers/tokenizer_manager.py (view file @ b2ed5c8e)

@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 import fastapi
 import uvloop
...
@@ -173,6 +173,15 @@ class TokenizerManager:
         # Others
         self.gracefully_exit = False

+        self.init_weights_update_group_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.update_weights_from_distributed_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.get_weights_by_name_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+
         # Metrics
         if self.enable_metrics:
...
@@ -190,8 +199,7 @@ class TokenizerManager:
     ):
         created_time = time.time()

-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
...
@@ -440,8 +448,7 @@ class TokenizerManager:
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         # default the load format to the server_args
         if obj.load_format is None:
...
@@ -456,7 +463,7 @@ class TokenizerManager:
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
-    ) -> Tuple[bool, str, int]:
+    ) -> Tuple[bool, str]:
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
...
@@ -485,15 +492,11 @@ class TokenizerManager:
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
-        self.send_to_scheduler.send_pyobj(obj)
-        self.init_weights_update_group_result = asyncio.Future()
+        self.auto_create_handle_loop()
         assert (
             self.server_args.dp_size == 1
         ), "dp_size must be 1 for init parameter update group"
-        result = await self.init_weights_update_group_result
+        result = (await self.init_weights_update_group_communicator(obj))[0]
         return result.success, result.message

     async def update_weights_from_distributed(
...
@@ -501,44 +504,32 @@ class TokenizerManager:
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be for update weights from distributed"

         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            self.send_to_scheduler.send_pyobj(obj)
-            self.parameter_update_result: Awaitable[
-                UpdateWeightsFromDistributedReqOutput
-            ] = asyncio.Future()
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be for update weights from distributed"
-            result = await self.parameter_update_result
+            result = (await self.update_weights_from_distributed_communicator(obj))[0]
             return result.success, result.message

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
-        if self.to_create_loop:
-            self.create_handle_loop()
-
-        self.send_to_scheduler.send_pyobj(obj)
-        self.get_weights_by_name_result = asyncio.Future()
+        self.auto_create_handle_loop()
+        results = await self.get_weights_by_name_communicator(obj)
+        all_parameters = [r.parameter for r in results]
         if self.server_args.dp_size == 1:
-            result = await self.get_weights_by_name_result
-            return result.parameter
+            return all_parameters[0]
         else:
-            self.get_weights_by_name_tmp = []
-            result = await self.get_weights_by_name_result
-            all_parameters = [r.parameter for r in result]
             return all_parameters

     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         session_id = uuid.uuid4().hex
         obj.session_id = session_id
...
@@ -568,7 +559,7 @@ class TokenizerManager:
             background_tasks.add_task(abort_request)
         return background_tasks

-    def create_handle_loop(self):
+    def auto_create_handle_loop(self):
         if not self.to_create_loop:
             return
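The rename above works because the creation guard now lives inside the method itself, so every call site collapses to a single unconditional self.auto_create_handle_loop(). A minimal sketch of the idiom, using a hypothetical Manager class rather than the real TokenizerManager (whose loop also wires up the scheduler result handling):

    import asyncio


    class Manager:
        def __init__(self):
            self.to_create_loop = True  # flips to False once the loop task exists

        def auto_create_handle_loop(self):
            # The guard lives here, so callers never check the flag themselves.
            if not self.to_create_loop:
                return
            self.to_create_loop = False
            asyncio.get_event_loop().create_task(self.handle_loop())

        async def handle_loop(self):
            ...  # would pull results from the scheduler and dispatch them


    async def main():
        m = Manager()
        m.auto_create_handle_loop()  # first call creates the background task
        m.auto_create_handle_loop()  # subsequent calls are no-ops


    asyncio.run(main())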
...
@@ -711,21 +702,14 @@ class TokenizerManager:
                 assert (
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for init parameter update group"
-                self.init_weights_update_group_result.set_result(recv_obj)
+                self.init_weights_update_group_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                 assert (
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for update weights from distributed"
-                self.parameter_update_result.set_result(recv_obj)
+                self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
-                if self.server_args.dp_size == 1:
-                    self.get_weights_by_name_result.set_result(recv_obj)
-                else:
-                    self.get_weights_by_name_tmp.append(recv_obj)
-                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
-                        self.get_weights_by_name_result.set_result(
-                            self.get_weights_by_name_tmp
-                        )
+                self.get_weights_by_name_communicator.handle_recv(recv_obj)
             else:
                 raise ValueError(f"Invalid object: {recv_obj=}")
...
@@ -809,3 +793,28 @@ class SignalHandler:
             f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
         )
         self.tokenizer_manager.gracefully_exit = True
+
+
+T = TypeVar("T")
+
+
+class _Communicator(Generic[T]):
+    def __init__(self, sender, fan_out: int):
+        self._sender = sender
+        self._fan_out = fan_out
+        self._result_future: Optional[asyncio.Future] = None
+        self._result_values: Optional[List[T]] = None
+
+    async def __call__(self, obj):
+        self._sender.send_pyobj(obj)
+        self._result_future = asyncio.Future()
+        self._result_values = []
+        await self._result_future
+        result_values = self._result_values
+        self._result_future = self._result_values = None
+        return result_values
+
+    def handle_recv(self, recv_obj: T):
+        self._result_values.append(recv_obj)
+        if len(self._result_values) == self._fan_out:
+            self._result_future.set_result(None)
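For reference, a minimal sketch of how the new _Communicator is driven: one request fans out to the schedulers, and the awaiting caller only resumes once handle_recv has been fed one reply per rank. This assumes the _Communicator class above is in scope and substitutes a fake sender for the real ZMQ socket; fan_out plays the role of server_args.dp_size.

    import asyncio


    class FakeSender:
        # Stand-in for the ZMQ socket; only send_pyobj is needed here.
        def send_pyobj(self, obj):
            print(f"sent {obj!r} to the scheduler(s)")


    async def main():
        fan_out = 2  # plays the role of server_args.dp_size
        communicator = _Communicator(FakeSender(), fan_out)

        # The caller awaits the communicator; it resolves only after fan_out
        # replies have been fed in through handle_recv().
        request = asyncio.create_task(communicator("get_weights_by_name request"))
        await asyncio.sleep(0)  # let the request send and park on its future

        communicator.handle_recv("reply from dp rank 0")
        communicator.handle_recv("reply from dp rank 1")

        replies = await request
        print(replies)  # ['reply from dp rank 0', 'reply from dp rank 1']


    asyncio.run(main())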
python/sglang/srt/server.py (view file @ b2ed5c8e)

@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     try:
         ret = await tokenizer_manager.get_weights_by_name(obj, request)
         if ret is None:
-            return ORJSONResponse(
-                {"error": {"message": "Get parameter by name failed"}},
-                status_code=HTTPStatus.BAD_REQUEST,
-            )
+            return _create_error_response("Get parameter by name failed")
         else:
             return ORJSONResponse(ret, status_code=200)
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/open_session", methods=["GET", "POST"])
...
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
         session_id = await tokenizer_manager.open_session(obj, request)
         return session_id
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/close_session", methods=["GET", "POST"])
...
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         await tokenizer_manager.close_session(obj, request)
         return Response(status_code=200)
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 # fastapi implicitly converts json in the request to obj (dataclass)
...
@@ -312,9 +303,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         return ret
     except ValueError as e:
         logger.error(f"Error: {e}")
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/encode", methods=["POST", "PUT"])
...
@@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/classify", methods=["POST", "PUT"])
...
@@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 ##### OpenAI-compatible API endpoints #####
...
@@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)


+def _create_error_response(e):
+    return ORJSONResponse(
+        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+    )
+
+
 def launch_engine(
     server_args: ServerArgs,
 ):
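The handlers changed above all funnel their error formatting through this one helper instead of repeating the ORJSONResponse construction inline. A self-contained sketch of the pattern, with a hypothetical /do_thing endpoint that is not part of the actual server:

    from http import HTTPStatus

    from fastapi import FastAPI
    from fastapi.responses import ORJSONResponse

    app = FastAPI()


    def _create_error_response(e):
        # One place that formats error payloads and picks the status code.
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


    @app.api_route("/do_thing", methods=["POST"])
    async def do_thing():
        try:
            raise ValueError("something went wrong")  # stand-in for real work
        except ValueError as e:
            return _create_error_response(e)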
...
@@ -849,12 +840,10 @@ class Engine:
             group_name=group_name,
             backend=backend,
         )

-        async def _init_group():
-            return await tokenizer_manager.init_weights_update_group(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_init_group())
+        return loop.run_until_complete(
+            tokenizer_manager.init_weights_update_group(obj, None)
+        )

     def update_weights_from_distributed(self, name, dtype, shape):
         """Update weights from distributed source."""
...
@@ -863,22 +852,16 @@ class Engine:
             dtype=dtype,
             shape=shape,
         )

-        async def _update_weights():
-            return await tokenizer_manager.update_weights_from_distributed(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_update_weights())
+        return loop.run_until_complete(
+            tokenizer_manager.update_weights_from_distributed(obj, None)
+        )

     def get_weights_by_name(self, name, truncate_size=100):
         """Get weights by parameter name."""
         obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)

-        async def _get_weights():
-            return await tokenizer_manager.get_weights_by_name(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_get_weights())
+        return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))


 class Runtime:
...
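The Engine changes above rely on the fact that loop.run_until_complete accepts a coroutine object directly, so the throwaway nested async def wrappers were never needed. A standalone sketch of the before/after, with fetch_value standing in for the tokenizer_manager call:

    import asyncio


    async def fetch_value():
        # Stands in for e.g. tokenizer_manager.get_weights_by_name(obj, None).
        return 42


    # Before: wrap the call in a one-off coroutine function.
    async def _get_weights():
        return await fetch_value()

    loop = asyncio.get_event_loop()
    print(loop.run_until_complete(_get_weights()))  # 42

    # After: pass the coroutine object straight to run_until_complete.
    print(loop.run_until_complete(fetch_value()))  # 42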