sglang / Commits / 8210ec60

Unverified commit 8210ec60, authored May 17, 2024 by Lianmin Zheng, committed by GitHub on May 17, 2024
Improve error handling & abort disconnected requests (#449)
Parent: 5be9eb8a

Showing 9 changed files with 198 additions and 126 deletions (+198, -126)
python/sglang/backend/runtime_endpoint.py        +18  -10
python/sglang/lang/interpreter.py                 +7   -3
python/sglang/srt/managers/io_struct.py          +10   -7
python/sglang/srt/managers/router/model_rpc.py   +49  -31
python/sglang/srt/managers/tokenizer_manager.py  +88  -58
python/sglang/srt/openai_api_adapter.py           +1   -1
python/sglang/srt/server.py                      +18  -13
python/sglang/test/test_utils.py                  +1   -1
python/sglang/utils.py                            +6   -2
python/sglang/backend/runtime_endpoint.py

@@ -34,7 +34,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         self.model_info = res.json()

         self.chat_template = get_chat_template_by_model_path(

@@ -50,7 +50,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
-        return res.status_code == 200
+        self._assert_success(res)

     def get_server_args(self):
         res = http_request(

@@ -58,6 +58,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
+        self._assert_success(res)
         return res.json()

     def get_chat_template(self):

@@ -71,7 +72,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def commit_lazy_operations(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}

@@ -83,7 +84,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}

@@ -95,7 +96,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def generate(
         self,

@@ -133,6 +134,8 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
+
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]

@@ -167,7 +170,7 @@ class RuntimeEndpoint(BaseBackend):
         data["stream"] = True
         self._add_images(s, data)
-        response = http_request(
+        res = http_request(
             self.base_url + "/generate",
             json=data,
             stream=True,

@@ -175,10 +178,11 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
         pos = 0
         incomplete_text = ""
-        for chunk in response.iter_lines(decode_unicode=False):
+        for chunk in res.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
             if chunk and chunk.startswith("data:"):
                 if chunk == "data: [DONE]":

@@ -211,7 +215,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         prompt_len = res.json()["meta_info"]["prompt_tokens"]

         # Compute logprob

@@ -229,7 +233,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         obj = res.json()
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj

@@ -253,9 +257,13 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
             data["image_data"] = s.images_[0][1]
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
\ No newline at end of file
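The change above replaces bare status-code asserts with a single _assert_success helper that raises a RuntimeError carrying the server's JSON error body, so callers see why a request failed instead of a bare AssertionError. Below is a minimal, self-contained sketch of that pattern; it is not part of the commit, and FakeResponse is only a stand-in for the HttpResponse wrapper returned by sglang's http_request.

class FakeResponse:
    # Mimics a response object exposing status_code and json(), like HttpResponse.
    def __init__(self, status_code, body):
        self.status_code = status_code
        self._body = body

    def json(self):
        return self._body


def _assert_success(res):
    # Raise instead of assert, so the server's JSON error body reaches the caller.
    if res.status_code != 200:
        raise RuntimeError(res.json())


ok = FakeResponse(200, {"text": "hello"})
bad = FakeResponse(400, {"error": {"message": "Abort request abc123"}})

_assert_success(ok)  # passes silently
try:
    _assert_success(bad)
except RuntimeError as e:
    print("server error:", e)  # prints the JSON error payload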
python/sglang/lang/interpreter.py

@@ -191,7 +191,7 @@ class StreamExecutor:
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
-        self.error = None
+        self.error_ = None

         # For completion
         self.text_ = ""  # The full text

@@ -300,6 +300,10 @@ class StreamExecutor:
         self.sync()
         return self.messages_

+    def error(self):
+        self.sync()
+        return self.error_
+
     def end(self):
         if self.use_thread:
             if self.worker.is_alive():

@@ -338,7 +342,7 @@ class StreamExecutor:
         if self.stream_var_event:
             for name in self.stream_var_event:
                 self.stream_var_event[name].set()
-        self.error = error
+        self.error_ = error

         if self.stream_text_event:
             self.stream_text_event.set()

@@ -713,7 +717,7 @@ class ProgramState:
         return self.stream_executor.sync()

     def error(self):
-        return self.stream_executor.error
+        return self.stream_executor.error()

     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
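The interpreter change stores the failure in error_ and exposes error() as a method that synchronizes the executor before reporting, so callers always see the final outcome rather than a half-updated attribute. The sketch below illustrates that pattern only; TinyExecutor and its worker thread are hypothetical, not the real StreamExecutor.

import threading


class TinyExecutor:
    def __init__(self):
        self.error_ = None            # was previously a plain `error` attribute
        self._done = threading.Event()

    def _finish(self, error=None):
        # Worker records the error first, then signals completion.
        self.error_ = error
        self._done.set()

    def sync(self):
        self._done.wait()

    def error(self):
        # Callers see the final error only after the worker has finished.
        self.sync()
        return self.error_


ex = TinyExecutor()
threading.Thread(target=ex._finish, args=(ValueError("boom"),)).start()
print(ex.error())  # waits for the worker, then prints the recorded error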
python/sglang/srt/managers/io_struct.py

@@ -31,12 +31,9 @@ class GenerateReqInput:
     def post_init(self):
-        if self.text is None:
-            assert (
-                self.input_ids is not None
-            ), "Either text or input_ids should be provided"
-        else:
-            assert self.input_ids is None, "Either text or input_ids should be provided"
+        if ((self.text is None and self.input_ids is None) or
+            (self.text is not None and self.input_ids is not None)):
+            raise ValueError("Either text or input_ids should be provided.")

         if self.text is not None:
             is_single = isinstance(self.text, str)

@@ -71,7 +68,8 @@ class GenerateReqInput:
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
         else:
-            assert isinstance(self.rid, list)
+            if not isinstance(self.rid, list):
+                raise ValueError("The rid should be a list.")
         if self.return_logprob is None:
             self.return_logprob = [False] * num

@@ -129,6 +127,11 @@ class FlushCacheReq:
     pass


+@dataclass
+class AbortReq:
+    rid: str
+
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
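io_struct.py now validates request input with ValueError instead of assert, and introduces an AbortReq message that carries only a request id. The following self-contained sketch shows how such messages fit together; TinyGenerateReq is an illustrative stand-in, not the real GenerateReqInput.

import uuid
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class AbortReq:
    rid: str


@dataclass
class TinyGenerateReq:
    text: Optional[Union[str, List[str]]] = None
    input_ids: Optional[List[int]] = None
    rid: Optional[str] = None

    def post_init(self):
        # Exactly one of text / input_ids must be provided; raise instead of assert.
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")
        if self.rid is None:
            self.rid = uuid.uuid4().hex


req = TinyGenerateReq(text="Hello")
req.post_init()
print(AbortReq(rid=req.rid))  # e.g. AbortReq(rid='3f2a...')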
python/sglang/srt/managers/router/model_rpc.py

@@ -20,6 +20,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,

@@ -110,6 +111,8 @@ class ModelRpcServer:
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
         set_random_seed(server_args.random_seed)
+
+        # Print info
         logger.info(
             f"Rank {self.tp_rank}: "
             f"max_total_num_token={self.max_total_num_token}, "

@@ -160,24 +163,6 @@ class ModelRpcServer:
         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)

-    def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
-            self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
-            self.req_to_token_pool.clear()
-            self.token_to_kv_pool.clear()
-            torch.cuda.empty_cache()
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
-
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)

@@ -189,6 +174,8 @@ class ModelRpcServer:
                 self.handle_generate_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
+            elif isinstance(recv_req, AbortReq):
+                self.abort_request(recv_req)
             else:
                 raise ValueError(f"Invalid request: {recv_req}")

@@ -207,9 +194,8 @@ class ModelRpcServer:
         new_batch = self.get_new_fill_batch()

         if new_batch is not None:
-            # Run new fill batch
+            # Run a new fill batch
             self.forward_fill_batch(new_batch)
             self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():

@@ -225,14 +211,8 @@ class ModelRpcServer:
                 self.num_generated_tokens += len(self.running_batch.reqs)
                 self.forward_decode_batch(self.running_batch)

-                if self.running_batch.is_empty():
-                    self.running_batch = None
-                    break
-
-                if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                    break
-
-                if self.running_batch is not None and self.tp_rank == 0:
+                # Print stats
+                if self.tp_rank == 0:
                     if self.decode_forward_ct % 40 == 0:
                         num_used = self.max_total_num_token - (
                             self.token_to_kv_pool.available_size()

@@ -250,8 +230,15 @@ class ModelRpcServer:
                             f"gen throughput (token/s): {throuhgput:.2f}, "
                             f"#queue-req: {len(self.forward_queue)}"
                         )

+                if self.running_batch.is_empty():
+                    self.running_batch = None
+                    break
+
+                if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                    break
+
             else:
-                # check the available size
+                # Check the available size
                 available_size = (
                     self.token_to_kv_pool.available_size()
                     + self.tree_cache.evictable_size()

@@ -295,7 +282,7 @@ class ModelRpcServer:
                     req.sampling_params.regex
                 )

-            # Truncate long prompts
+            # Truncate prompts that are too long
             req.input_ids = req.input_ids[: self.model_config.context_len - 1]
             req.sampling_params.max_new_tokens = min(
                 req.sampling_params.max_new_tokens,

@@ -311,6 +298,7 @@ class ModelRpcServer:
         ):
             return None

+        # Compute matched prefix length
         for req in self.forward_queue:
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:

@@ -383,6 +371,7 @@ class ModelRpcServer:
         if len(can_run_list) == 0:
             return None

+        # Print stats
         if self.tp_rank == 0:
             running_req = (
                 0 if self.running_batch is None else len(self.running_batch.reqs)

@@ -410,6 +399,7 @@ class ModelRpcServer:
             #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
             # )

+        # Return the new batch
         new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,

@@ -487,7 +477,7 @@ class ModelRpcServer:
         self.handle_finished_requests(batch)

     def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],

@@ -671,6 +661,34 @@ class ModelRpcServer:
         else:
             batch.reqs = []

+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
+    def abort_request(self, recv_req):
+        to_del = None
+        for i, req in enumerate(self.forward_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.forward_queue[to_del]
+

 class ModelRpcService(rpyc.Service):
     exposed_ModelRpcServer = ModelRpcServer
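On the router side, an AbortReq only removes a request that is still waiting in the forward queue; requests already running are not touched by this method. Below is a simplified, self-contained sketch of that behavior, where TinyScheduler and QueuedReq are hypothetical stand-ins for ModelRpcServer and its queued requests.

from dataclasses import dataclass


@dataclass
class AbortReq:
    rid: str


@dataclass
class QueuedReq:
    rid: str
    prompt: str


class TinyScheduler:
    def __init__(self):
        self.forward_queue = []

    def abort_request(self, recv_req):
        # Find the first queued request with a matching id and drop it.
        to_del = None
        for i, req in enumerate(self.forward_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break
        if to_del is not None:
            del self.forward_queue[to_del]


sched = TinyScheduler()
sched.forward_queue = [QueuedReq("a", "hi"), QueuedReq("b", "bye")]
sched.abort_request(AbortReq(rid="a"))
print([r.rid for r in sched.forward_queue])  # ['b']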
python/sglang/srt/managers/tokenizer_manager.py

@@ -19,6 +19,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
     FlushCacheReq,
     GenerateReqInput,

@@ -42,52 +43,6 @@ class ReqState:
     event: asyncio.Event


-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size != None:
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image,
-                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
-                )
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            elif image_aspect_ratio == "anyres":
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            pixel_values = pixel_values.astype(np.float16)
-            return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-
 class TokenizerManager:
     def __init__(
         self,

@@ -154,10 +109,11 @@ class TokenizerManager:
                 image_data, aspect_ratio, grid_pinpoints, self.processor
             )

-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-            await self.create_handle_loop()
+            self.create_handle_loop()

+        obj.post_init()
         is_single = obj.is_single

         if is_single:
             rid = obj.rid

@@ -170,7 +126,7 @@ class TokenizerManager:
             if len(input_ids) >= self.context_len:
                 raise ValueError(
                     f"The input ({len(input_ids)} tokens) is longer than the "
-                    f"model's context length ({self.context_len} tokens)"
+                    f"model's context length ({self.context_len} tokens)."
                 )

             sampling_params = SamplingParams(**obj.sampling_params)

@@ -208,7 +164,14 @@ class TokenizerManager:
             self.rid_to_state[rid] = state

             while True:
-                await event.wait()
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=5)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
                 out = self.convert_logprob_style(
                     state.out_list[-1],
                     obj.return_logprob,

@@ -226,7 +189,8 @@ class TokenizerManager:
                     break

                 event.clear()
         else:
-            assert obj.stream is False
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")

             if obj.input_ids is None:
                 bs = len(obj.text)

@@ -276,7 +240,18 @@ class TokenizerManager:
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-                await state.event.wait()
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=5)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
                 output_list.append(
                     self.convert_logprob_style(
                         state.out_list[-1],

@@ -290,11 +265,16 @@ class TokenizerManager:
             yield output_list

-    async def flush_cache(self):
-        flush_cache_req = FlushCacheReq()
-        self.send_to_router.send_pyobj(flush_cache_req)
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)
+
+    def abort_request(self, rid):
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)

-    async def create_handle_loop(self):
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())

@@ -305,17 +285,20 @@ class TokenizerManager:
             if isinstance(recv_obj, BatchStrOut):
                 for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+
                     recv_obj.meta_info[i]["id"] = rid
                     out_dict = {
                         "text": recv_obj.output_str[i],
                         "meta_info": recv_obj.meta_info[i],
                     }
-                    state = self.rid_to_state[rid]
                     state.out_list.append(out_dict)
                     state.finished = recv_obj.finished[i]
                     state.event.set()
             else:
-                raise ValueError(f"Invalid object: {recv_obj}")
+                raise ValueError(f"Invalid object: {recv_obj}.")

     def convert_logprob_style(
         self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs

@@ -356,3 +339,50 @@ class TokenizerManager:
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
         return top_logprobs
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())
\ No newline at end of file
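The key change in tokenizer_manager.py is the wait loop: instead of awaiting the completion event indefinitely, the manager waits in short timeouts and, on each timeout, checks whether the HTTP client has disconnected, aborting the request if so. Here is a runnable sketch of that polling pattern; it is not the real TokenizerManager, is_disconnected stands in for Starlette's Request.is_disconnected, and the timeout is shortened from the 5 seconds used in the commit just to keep the demo fast.

import asyncio


async def wait_or_abort(event: asyncio.Event, is_disconnected, rid: str, timeout=0.1):
    while True:
        try:
            # Wake up periodically instead of blocking forever on event.wait().
            await asyncio.wait_for(event.wait(), timeout=timeout)
            return
        except asyncio.TimeoutError:
            if await is_disconnected():
                # In the real code this also sends an AbortReq to the router.
                raise ValueError(f"Abort request {rid}")
            continue


async def main():
    event = asyncio.Event()  # never set: simulate a request that is still running

    async def disconnected():
        return True  # pretend the client went away

    try:
        await wait_or_abort(event, disconnected, rid="demo")
    except ValueError as e:
        print(e)  # Abort request demo


asyncio.run(main())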
python/sglang/srt/openai_api_adapter.py

@@ -335,7 +335,7 @@ def to_openai_style_logprobs(
             ret_logprobs.tokens.append(token_text)
             ret_logprobs.token_logprobs.append(logprob)

-            # Not Supported yet
+            # Not supported yet
             ret_logprobs.text_offset.append(-1)

     def append_top_logprobs(top_logprobs):
python/sglang/srt/server.py

@@ -10,6 +10,7 @@ import sys
 import threading
 import time
 from typing import List, Optional, Union
+from http import HTTPStatus

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -73,7 +74,7 @@ async def get_server_args():
 @app.get("/flush_cache")
 async def flush_cache():
-    await tokenizer_manager.flush_cache()
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",

@@ -81,24 +82,25 @@ async def flush_cache():
     )


-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:

         async def stream_results():
-            async for out in tokenizer_manager.generate_request(obj):
-                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"

         return StreamingResponse(stream_results(), media_type="text/event-stream")
     else:
         try:
-            ret = await tokenizer_manager.generate_request(obj).__anext__()
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
             return ret
         except ValueError as e:
-            print(f"Error: {e}")
-            return JSONResponse({"error": str(e)}, status_code=400)
+            return JSONResponse({"error": {"message": str(e)}},
+                status_code=HTTPStatus.BAD_REQUEST)


 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)

@@ -186,6 +188,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

+    # Send a warmup request
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()

@@ -228,6 +231,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     t = threading.Thread(target=_wait_and_warmup)
     t.start()

+    # Listen for requests
+
     try:
         uvicorn.run(
             app,
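With the server.py change, a ValueError raised during generation (including the "Abort request ..." error from the tokenizer manager) is turned into an HTTP 400 response whose body nests the message under "error", and streaming requests emit the same shape as a data chunk before closing. The sketch below only illustrates that response shape; handle_generate and failing_request are hypothetical helpers, not the real FastAPI endpoint.

import json
from http import HTTPStatus


def handle_generate(run_request):
    # Mirror the endpoint's behavior: success passes through, ValueError becomes
    # a 400 with a nested error message.
    try:
        return HTTPStatus.OK, run_request()
    except ValueError as e:
        return HTTPStatus.BAD_REQUEST, {"error": {"message": str(e)}}


def failing_request():
    raise ValueError("Abort request demo")


status, body = handle_generate(failing_request)
print(status.value, json.dumps(body))  # 400 {"error": {"message": "Abort request demo"}}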
python/sglang/test/test_utils.py

@@ -9,7 +9,7 @@ import requests
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
-from sglang.srt.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback


 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
python/sglang/utils.py

@@ -93,8 +93,12 @@ def http_request(
         data = None
     else:
         data = bytes(dumps(json), encoding="utf-8")

-    resp = urllib.request.urlopen(req, data=data, cafile=verify)
-    return HttpResponse(resp)
+    try:
+        resp = urllib.request.urlopen(req, data=data, cafile=verify)
+        return HttpResponse(resp)
+    except urllib.error.HTTPError as e:
+        return HttpResponse(e)


 def encode_image_base64(image_path):
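The utils.py change makes http_request return a response object even when the server answers with an error status, which is what lets the backend's _assert_success read the JSON error body instead of dealing with an exception from urllib. The sketch below shows why wrapping HTTPError works; this HttpResponse is a simplified stand-in for the real one, and http_get is an illustrative helper rather than sglang's http_request.

import json
import urllib.error
import urllib.request


class HttpResponse:
    def __init__(self, resp):
        self.resp = resp

    @property
    def status_code(self):
        # getcode() is available on both normal responses and HTTPError objects,
        # since HTTPError is itself a file-like response.
        return self.resp.getcode()

    def json(self):
        return json.loads(self.resp.read())


def http_get(url):
    try:
        return HttpResponse(urllib.request.urlopen(url))
    except urllib.error.HTTPError as e:
        # 4xx/5xx answers come back as a normal response object instead of raising.
        return HttpResponse(e)

# Example (requires network): http_get("https://httpbin.org/status/404").status_code -> 404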