zhaoyu6 / sglang · Commits

Commit 8210ec60 (unverified)
Authored May 17, 2024 by Lianmin Zheng; committed by GitHub on May 17, 2024

Improve error handling & abort disconnected requests (#449)

Parent: 5be9eb8a

Showing 9 changed files with 198 additions and 126 deletions (+198, -126)
python/sglang/backend/runtime_endpoint.py        +18  -10
python/sglang/lang/interpreter.py                 +7   -3
python/sglang/srt/managers/io_struct.py          +10   -7
python/sglang/srt/managers/router/model_rpc.py   +49  -31
python/sglang/srt/managers/tokenizer_manager.py  +88  -58
python/sglang/srt/openai_api_adapter.py           +1   -1
python/sglang/srt/server.py                      +18  -13
python/sglang/test/test_utils.py                  +1   -1
python/sglang/utils.py                            +6   -2
python/sglang/backend/runtime_endpoint.py

```diff
@@ -34,7 +34,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         self.model_info = res.json()

         self.chat_template = get_chat_template_by_model_path(
@@ -50,7 +50,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
-        return res.status_code == 200
+        self._assert_success(res)

     def get_server_args(self):
         res = http_request(
@@ -58,6 +58,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
+        self._assert_success(res)
         return res.json()

     def get_chat_template(self):
@@ -71,7 +72,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def commit_lazy_operations(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -83,7 +84,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -95,7 +96,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def generate(
         self,
@@ -133,6 +134,8 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
+
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]
@@ -167,7 +170,7 @@ class RuntimeEndpoint(BaseBackend):
             data["stream"] = True
             self._add_images(s, data)

-            response = http_request(
+            res = http_request(
                 self.base_url + "/generate",
                 json=data,
                 stream=True,
@@ -175,10 +178,11 @@ class RuntimeEndpoint(BaseBackend):
                 api_key=self.api_key,
                 verify=self.verify,
             )
+            self._assert_success(res)
             pos = 0

             incomplete_text = ""
-            for chunk in response.iter_lines(decode_unicode=False):
+            for chunk in res.iter_lines(decode_unicode=False):
                 chunk = chunk.decode("utf-8")
                 if chunk and chunk.startswith("data:"):
                     if chunk == "data: [DONE]":
@@ -211,7 +215,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         prompt_len = res.json()["meta_info"]["prompt_tokens"]

         # Compute logprob
@@ -229,7 +233,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         obj = res.json()
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
@@ -253,9 +257,13 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
             data["image_data"] = s.images_[0][1]
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
\ No newline at end of file
```
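The recurring change in this file replaces bare `assert res.status_code == 200` checks with a single `_assert_success` helper that raises a `RuntimeError` carrying the server's JSON error body, so a failed backend call reports why it failed instead of a bare `AssertionError`. A minimal, self-contained sketch of that pattern (the `FakeResponse` class and the `assert_success` free function are illustrative stand-ins, not part of sglang):

```python
class FakeResponse:
    """Stand-in for the response object returned by http_request (illustrative only)."""

    def __init__(self, status_code, body):
        self.status_code = status_code
        self._body = body

    def json(self):
        return self._body


def assert_success(res):
    # Instead of `assert res.status_code == 200`, surface the server's
    # error payload so the caller can see *why* the request failed.
    if res.status_code != 200:
        raise RuntimeError(res.json())


if __name__ == "__main__":
    ok = FakeResponse(200, {"text": "hello"})
    assert_success(ok)  # passes silently

    bad = FakeResponse(400, {"error": {"message": "Either text or input_ids should be provided."}})
    try:
        assert_success(bad)
    except RuntimeError as e:
        print("request failed:", e)
```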
python/sglang/lang/interpreter.py

```diff
@@ -191,7 +191,7 @@ class StreamExecutor:
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
-        self.error = None
+        self.error_ = None

         # For completion
         self.text_ = ""  # The full text
@@ -300,6 +300,10 @@ class StreamExecutor:
         self.sync()
         return self.messages_

+    def error(self):
+        self.sync()
+        return self.error_
+
     def end(self):
         if self.use_thread:
             if self.worker.is_alive():
@@ -338,7 +342,7 @@ class StreamExecutor:
         if self.stream_var_event:
             for name in self.stream_var_event:
                 self.stream_var_event[name].set()
-        self.error = error
+        self.error_ = error
         if self.stream_text_event:
             self.stream_text_event.set()
@@ -713,7 +717,7 @@ class ProgramState:
         return self.stream_executor.sync()

     def error(self):
-        return self.stream_executor.error
+        return self.stream_executor.error()

     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
```
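Renaming the attribute to `error_` frees the name `error` for an accessor that syncs the executor before reporting, matching how `text_` and `messages_` already pair with accessor methods. A toy sketch of that attribute/accessor split with a worker thread (this `ToyExecutor` is illustrative, not the real `StreamExecutor`):

```python
import threading


class ToyExecutor:
    """Minimal illustration of the error_ attribute / error() accessor split."""

    def __init__(self):
        self.error_ = None           # set by the worker thread on failure
        self._done = threading.Event()

    def _worker(self):
        try:
            raise ValueError("backend rejected the request")
        except ValueError as e:
            self.error_ = e          # record the error instead of crashing the worker
        finally:
            self._done.set()

    def run(self):
        threading.Thread(target=self._worker).start()

    def sync(self):
        self._done.wait()            # block until the worker has finished

    def error(self):
        # Public accessor: flush pending work first, then report whatever
        # error (if any) the worker recorded.
        self.sync()
        return self.error_


if __name__ == "__main__":
    ex = ToyExecutor()
    ex.run()
    print("error:", ex.error())
```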
python/sglang/srt/managers/io_struct.py

```diff
@@ -31,12 +31,9 @@ class GenerateReqInput:
     def post_init(self):
-        if self.text is None:
-            assert (
-                self.input_ids is not None
-            ), "Either text or input_ids should be provided"
-        else:
-            assert self.input_ids is None, "Either text or input_ids should be provided"
+        if ((self.text is None and self.input_ids is None) or
+            (self.text is not None and self.input_ids is not None)):
+            raise ValueError("Either text or input_ids should be provided.")

         if self.text is not None:
             is_single = isinstance(self.text, str)
@@ -71,7 +68,8 @@ class GenerateReqInput:
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
         else:
-            assert isinstance(self.rid, list)
+            if not isinstance(self.rid, list):
+                raise ValueError("The rid should be a list.")

         if self.return_logprob is None:
             self.return_logprob = [False] * num
@@ -129,6 +127,11 @@ class FlushCacheReq:
     pass


+@dataclass
+class AbortReq:
+    rid: str
+
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
```
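Replacing `assert` with `raise ValueError(...)` means the check still runs under `python -O` and, more importantly, gives the HTTP layer a catchable exception it can turn into a 400 response (see the `server.py` change below). A reduced sketch of the new validation rule (this dataclass is a stripped-down stand-in for `GenerateReqInput`, not the full definition):

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class ToyGenerateReqInput:
    text: Optional[Union[str, List[str]]] = None
    input_ids: Optional[Union[List[int], List[List[int]]]] = None

    def post_init(self):
        # Exactly one of `text` / `input_ids` must be given.
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")


if __name__ == "__main__":
    ToyGenerateReqInput(text="hello").post_init()   # ok
    try:
        ToyGenerateReqInput().post_init()           # neither field given
    except ValueError as e:
        print("rejected:", e)
```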
python/sglang/srt/managers/router/model_rpc.py

```diff
@@ -20,6 +20,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
@@ -110,6 +111,8 @@ class ModelRpcServer:
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
         set_random_seed(server_args.random_seed)
+
+        # Print info
         logger.info(
             f"Rank {self.tp_rank}: "
             f"max_total_num_token={self.max_total_num_token}, "
@@ -160,24 +163,6 @@ class ModelRpcServer:
         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)

-    def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
-            self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
-            self.req_to_token_pool.clear()
-            self.token_to_kv_pool.clear()
-            torch.cuda.empty_cache()
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
-
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)
@@ -189,6 +174,8 @@ class ModelRpcServer:
                 self.handle_generate_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
+            elif isinstance(recv_req, AbortReq):
+                self.abort_request(recv_req)
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
@@ -207,9 +194,8 @@ class ModelRpcServer:
         new_batch = self.get_new_fill_batch()

         if new_batch is not None:
-            # Run new fill batch
+            # Run a new fill batch
             self.forward_fill_batch(new_batch)
-
             self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
@@ -225,14 +211,8 @@ class ModelRpcServer:
                 self.num_generated_tokens += len(self.running_batch.reqs)
                 self.forward_decode_batch(self.running_batch)

-                if self.running_batch.is_empty():
-                    self.running_batch = None
-                    break
-
-                if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                    break
-
-                if self.running_batch is not None and self.tp_rank == 0:
+                # Print stats
+                if self.tp_rank == 0:
                     if self.decode_forward_ct % 40 == 0:
                         num_used = self.max_total_num_token - (
                             self.token_to_kv_pool.available_size()
@@ -250,8 +230,15 @@ class ModelRpcServer:
                             f"gen throughput (token/s): {throuhgput:.2f}, "
                             f"#queue-req: {len(self.forward_queue)}"
                         )
+
+                if self.running_batch.is_empty():
+                    self.running_batch = None
+                    break
+
+                if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                    break
         else:
-            # check the available size
+            # Check the available size
             available_size = (
                 self.token_to_kv_pool.available_size()
                 + self.tree_cache.evictable_size()
@@ -295,7 +282,7 @@ class ModelRpcServer:
                 req.sampling_params.regex
             )

-        # Truncate long prompts
+        # Truncate prompts that are too long
         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
             req.sampling_params.max_new_tokens,
@@ -311,6 +298,7 @@ class ModelRpcServer:
         ):
             return None

+        # Compute matched prefix length
         for req in self.forward_queue:
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
@@ -383,6 +371,7 @@ class ModelRpcServer:
         if len(can_run_list) == 0:
             return None

+        # Print stats
         if self.tp_rank == 0:
             running_req = (
                 0 if self.running_batch is None else len(self.running_batch.reqs)
@@ -410,6 +399,7 @@ class ModelRpcServer:
             #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
             # )

+        # Return the new batch
         new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,
@@ -487,7 +477,7 @@ class ModelRpcServer:
         self.handle_finished_requests(batch)

     def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],
@@ -671,6 +661,34 @@ class ModelRpcServer:
         else:
             batch.reqs = []

+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
+    def abort_request(self, recv_req):
+        to_del = None
+        for i, req in enumerate(self.forward_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.forward_queue[to_del]
+

 class ModelRpcService(rpyc.Service):
     exposed_ModelRpcServer = ModelRpcServer
```
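The new router-side `abort_request` simply scans the waiting queue for a matching request id and drops it; requests that have already been scheduled into a running batch are left to finish. A self-contained sketch of that queue scan (the `ToyReq` type and the free-function form are illustrative, not the router's actual types):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyReq:
    rid: str


def abort_request(forward_queue: List[ToyReq], rid: str) -> bool:
    """Remove the first queued request whose rid matches; return True if one was removed."""
    to_del = None
    for i, req in enumerate(forward_queue):
        if req.rid == rid:
            to_del = i
            break

    if to_del is not None:
        del forward_queue[to_del]
        return True
    return False


if __name__ == "__main__":
    queue = [ToyReq("a"), ToyReq("b"), ToyReq("c")]
    print(abort_request(queue, "b"), [r.rid for r in queue])    # True ['a', 'c']
    print(abort_request(queue, "zzz"), [r.rid for r in queue])  # False ['a', 'c']
```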
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -19,6 +19,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
     FlushCacheReq,
     GenerateReqInput,
@@ -42,52 +43,6 @@ class ReqState:
     event: asyncio.Event


-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size != None:
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image,
-                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
-                )
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            elif image_aspect_ratio == "anyres":
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            pixel_values = pixel_values.astype(np.float16)
-            return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-
 class TokenizerManager:
     def __init__(
         self,
@@ -154,10 +109,11 @@ class TokenizerManager:
                 image_data, aspect_ratio, grid_pinpoints, self.processor
             )

-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-            await self.create_handle_loop()
+            self.create_handle_loop()

+        obj.post_init()
         is_single = obj.is_single
         if is_single:
             rid = obj.rid
@@ -170,7 +126,7 @@ class TokenizerManager:
             if len(input_ids) >= self.context_len:
                 raise ValueError(
                     f"The input ({len(input_ids)} tokens) is longer than the "
-                    f"model's context length ({self.context_len} tokens)"
+                    f"model's context length ({self.context_len} tokens)."
                 )

             sampling_params = SamplingParams(**obj.sampling_params)
@@ -208,7 +164,14 @@ class TokenizerManager:
             self.rid_to_state[rid] = state

             while True:
-                await event.wait()
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=5)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
                 out = self.convert_logprob_style(
                     state.out_list[-1],
                     obj.return_logprob,
@@ -226,7 +189,8 @@ class TokenizerManager:
                     break
                 event.clear()
         else:
-            assert obj.stream is False
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")

             if obj.input_ids is None:
                 bs = len(obj.text)
@@ -276,7 +240,18 @@ class TokenizerManager:
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-                await state.event.wait()
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=5)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
                 output_list.append(
                     self.convert_logprob_style(
                         state.out_list[-1],
@@ -290,11 +265,16 @@ class TokenizerManager:
             yield output_list

-    async def flush_cache(self):
-        flush_cache_req = FlushCacheReq()
-        self.send_to_router.send_pyobj(flush_cache_req)
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)

-    async def create_handle_loop(self):
+    def abort_request(self, rid):
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
@@ -305,17 +285,20 @@ class TokenizerManager:
             if isinstance(recv_obj, BatchStrOut):
                 for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+
                     recv_obj.meta_info[i]["id"] = rid
                     out_dict = {
                         "text": recv_obj.output_str[i],
                         "meta_info": recv_obj.meta_info[i],
                     }
-                    state = self.rid_to_state[rid]
                     state.out_list.append(out_dict)
                     state.finished = recv_obj.finished[i]
                     state.event.set()
             else:
-                raise ValueError(f"Invalid object: {recv_obj}")
+                raise ValueError(f"Invalid object: {recv_obj}.")
@@ -356,3 +339,50 @@ class TokenizerManager:
                 if t:
                     top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
         return top_logprobs
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())
\ No newline at end of file
```
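The core of the abort logic lives in `generate_request`: instead of awaiting the result event forever, the manager wakes up on a short timeout and, if the HTTP client has disconnected, aborts the request and raises. A standalone approximation of that loop (the `FakeRequest` class only mimics Starlette's `Request.is_disconnected()` for the demo, and the 0.2 s poll interval is illustrative where the diff uses 5 s):

```python
import asyncio
import time


class FakeRequest:
    """Mimics starlette.requests.Request.is_disconnected() for this demo."""

    def __init__(self, disconnect_after: float):
        self._deadline = time.monotonic() + disconnect_after

    async def is_disconnected(self) -> bool:
        return time.monotonic() >= self._deadline


async def wait_for_result(event: asyncio.Event, request, rid: str, poll: float = 0.2):
    # Poll the completion event with a timeout so that a client that has gone
    # away is noticed even if the result never arrives.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=poll)
            return "result ready"
        except asyncio.TimeoutError:
            if request is not None and await request.is_disconnected():
                raise ValueError(f"Abort request {rid}")
            continue


async def main():
    event = asyncio.Event()  # never set: simulates a request whose result is stuck
    request = FakeRequest(disconnect_after=0.5)
    try:
        print(await wait_for_result(event, request, rid="demo"))
    except ValueError as e:
        print(e)  # -> Abort request demo


asyncio.run(main())
```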
python/sglang/srt/openai_api_adapter.py

```diff
@@ -335,7 +335,7 @@ def to_openai_style_logprobs(
             ret_logprobs.tokens.append(token_text)
             ret_logprobs.token_logprobs.append(logprob)

-            # Not Supported yet
+            # Not supported yet
             ret_logprobs.text_offset.append(-1)

     def append_top_logprobs(top_logprobs):
```
python/sglang/srt/server.py

```diff
@@ -10,6 +10,7 @@ import sys
 import threading
 import time
 from typing import List, Optional, Union
+from http import HTTPStatus

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -73,7 +74,7 @@ async def get_server_args():
 @app.get("/flush_cache")
 async def flush_cache():
-    await tokenizer_manager.flush_cache()
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -81,24 +82,25 @@ async def flush_cache():
     )


-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:

         async def stream_results():
-            async for out in tokenizer_manager.generate_request(obj):
-                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"

         return StreamingResponse(stream_results(), media_type="text/event-stream")
-
-    try:
-        ret = await tokenizer_manager.generate_request(obj).__anext__()
-        return ret
-    except ValueError as e:
-        print(f"Error: {e}")
-        return JSONResponse({"error": str(e)}, status_code=400)
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse(
+                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+            )


 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
@@ -186,6 +188,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

+    # Send a warmup request
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
@@ -228,6 +231,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     t = threading.Thread(target=_wait_and_warmup)
     t.start()

+    # Listen for requests
+
     try:
         uvicorn.run(
             app,
```
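Since the streaming endpoint now reports failures as an in-band `data:` event containing an `"error"` object before `data: [DONE]`, a client consuming the SSE stream should inspect each chunk. A hedged sketch of a minimal consumer (assumes the `requests` package and a local sglang server; the URL, port, and payload are examples only, not mandated by the diff):

```python
import json
import requests  # any HTTP client with streaming support would work


def stream_generate(url: str, payload: dict):
    resp = requests.post(url, json=payload, stream=True)
    for line in resp.iter_lines(decode_unicode=False):
        chunk = line.decode("utf-8")
        if not chunk.startswith("data:"):
            continue
        if chunk == "data: [DONE]":
            break
        data = json.loads(chunk[len("data:"):])
        if "error" in data:
            # The server converted a ValueError (e.g. an aborted or invalid
            # request) into an in-band error event before closing the stream.
            raise RuntimeError(data["error"]["message"])
        yield data


if __name__ == "__main__":
    # Example only: assumes an sglang server running locally on port 30000.
    for out in stream_generate(
        "http://localhost:30000/generate",
        {"text": "Hello", "sampling_params": {"max_new_tokens": 8}, "stream": True},
    ):
        print(out.get("text", ""))
```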
python/sglang/test/test_utils.py

```diff
@@ -9,7 +9,7 @@ import requests
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
-from sglang.srt.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback


 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
```
python/sglang/utils.py

```diff
@@ -93,8 +93,12 @@ def http_request(
         data = None
     else:
         data = bytes(dumps(json), encoding="utf-8")

-    resp = urllib.request.urlopen(req, data=data, cafile=verify)
-    return HttpResponse(resp)
+    try:
+        resp = urllib.request.urlopen(req, data=data, cafile=verify)
+        return HttpResponse(resp)
+    except urllib.error.HTTPError as e:
+        return HttpResponse(e)


 def encode_image_base64(image_path):
```
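`urllib.request.urlopen` raises `HTTPError` for non-2xx responses, but an `HTTPError` is itself file-like and carries the status code and body, so returning it wrapped in the same response type lets callers of `http_request` inspect the error payload, which is exactly what the new `_assert_success` in `runtime_endpoint.py` relies on. A simplified sketch of that idea (this `SimpleHttpResponse` is a stand-in, not the real `sglang.utils.HttpResponse`; the URL is an example and requires network access):

```python
import json
import urllib.error
import urllib.request


class SimpleHttpResponse:
    """Simplified stand-in for sglang.utils.HttpResponse."""

    def __init__(self, resp):
        # `resp` is either an HTTPResponse or an HTTPError; both expose
        # getcode() and read(), so the caller can treat them uniformly.
        self.resp = resp

    @property
    def status_code(self):
        return self.resp.getcode()

    def json(self):
        return json.loads(self.resp.read())


def http_get(url):
    req = urllib.request.Request(url)
    try:
        return SimpleHttpResponse(urllib.request.urlopen(req))
    except urllib.error.HTTPError as e:
        # Keep the error body available instead of crashing here.
        return SimpleHttpResponse(e)


if __name__ == "__main__":
    res = http_get("https://httpbin.org/status/404")
    print(res.status_code)  # 404 rather than an uncaught exception
```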