Commit 8210ec60 (Unverified) · sglang

Improve error handling & abort disconnected requests (#449)

Authored May 17, 2024 by Lianmin Zheng; committed by GitHub on May 17, 2024.
Parent: 5be9eb8a

Showing 9 changed files with 198 additions and 126 deletions (+198, -126):
  python/sglang/backend/runtime_endpoint.py        +18  -10
  python/sglang/lang/interpreter.py                 +7   -3
  python/sglang/srt/managers/io_struct.py          +10   -7
  python/sglang/srt/managers/router/model_rpc.py   +49  -31
  python/sglang/srt/managers/tokenizer_manager.py  +88  -58
  python/sglang/srt/openai_api_adapter.py           +1   -1
  python/sglang/srt/server.py                      +18  -13
  python/sglang/test/test_utils.py                  +1   -1
  python/sglang/utils.py                            +6   -2
python/sglang/backend/runtime_endpoint.py (view file @ 8210ec60)

@@ -34,7 +34,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         self.model_info = res.json()

         self.chat_template = get_chat_template_by_model_path(
@@ -50,7 +50,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
-        return res.status_code == 200
+        self._assert_success(res)

     def get_server_args(self):
         res = http_request(
@@ -58,6 +58,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
+        self._assert_success(res)
         return res.json()

     def get_chat_template(self):
@@ -71,7 +72,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def commit_lazy_operations(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -83,7 +84,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -95,7 +96,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def generate(
         self,
@@ -133,6 +134,8 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
+
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]
@@ -167,7 +170,7 @@ class RuntimeEndpoint(BaseBackend):
         data["stream"] = True
         self._add_images(s, data)

-        response = http_request(
+        res = http_request(
             self.base_url + "/generate",
             json=data,
             stream=True,
@@ -175,10 +178,11 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
         pos = 0

         incomplete_text = ""
-        for chunk in response.iter_lines(decode_unicode=False):
+        for chunk in res.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
             if chunk and chunk.startswith("data:"):
                 if chunk == "data: [DONE]":
@@ -211,7 +215,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         prompt_len = res.json()["meta_info"]["prompt_tokens"]

         # Compute logprob
@@ -229,7 +233,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         obj = res.json()
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
@@ -253,9 +257,13 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
             data["image_data"] = s.images_[0][1]
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
\ No newline at end of file
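Note: the theme of this file's changes is replacing bare `assert res.status_code == 200` checks with a single `_assert_success` helper that raises a RuntimeError carrying the server's JSON error body. A minimal usage sketch (assumptions: a reachable sglang server; the localhost URL and the try/except shape are illustrative, not part of this commit):

from sglang.backend.runtime_endpoint import RuntimeEndpoint

try:
    # Endpoint calls in this backend now funnel through _assert_success, so a
    # non-200 response surfaces as a catchable RuntimeError (with the server's
    # JSON error body) instead of an AssertionError with no payload.
    backend = RuntimeEndpoint("http://localhost:30000")
except RuntimeError as e:
    print("Server returned an error:", e)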
python/sglang/lang/interpreter.py (view file @ 8210ec60)

@@ -191,7 +191,7 @@ class StreamExecutor:
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
-        self.error = None
+        self.error_ = None

         # For completion
         self.text_ = ""  # The full text
@@ -300,6 +300,10 @@ class StreamExecutor:
         self.sync()
         return self.messages_

+    def error(self):
+        self.sync()
+        return self.error_
+
     def end(self):
         if self.use_thread:
             if self.worker.is_alive():
@@ -338,7 +342,7 @@ class StreamExecutor:
         if self.stream_var_event:
             for name in self.stream_var_event:
                 self.stream_var_event[name].set()
-        self.error = error
+        self.error_ = error

         if self.stream_text_event:
             self.stream_text_event.set()
@@ -713,7 +717,7 @@ class ProgramState:
         return self.stream_executor.sync()

     def error(self):
-        return self.stream_executor.error
+        return self.stream_executor.error()

     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
python/sglang/srt/managers/io_struct.py (view file @ 8210ec60)

@@ -31,12 +31,9 @@ class GenerateReqInput:

     def post_init(self):
-        if self.text is None:
-            assert (
-                self.input_ids is not None
-            ), "Either text or input_ids should be provided"
-        else:
-            assert self.input_ids is None, "Either text or input_ids should be provided"
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")

         if self.text is not None:
             is_single = isinstance(self.text, str)
@@ -71,7 +68,8 @@ class GenerateReqInput:
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-                assert isinstance(self.rid, list)
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")

             if self.return_logprob is None:
                 self.return_logprob = [False] * num
@@ -129,6 +127,11 @@ class FlushCacheReq:
     pass


+@dataclass
+class AbortReq:
+    rid: str
+
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
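Note: the validation in GenerateReqInput.post_init now raises ValueError instead of asserting, which the server layer turns into an HTTP 400. A self-contained sketch of the same exclusive-or check (ExampleReq is illustrative, not the sglang dataclass):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ExampleReq:
    text: Optional[str] = None
    input_ids: Optional[List[int]] = None

    def post_init(self):
        # Exactly one of text / input_ids must be provided.
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")


ExampleReq(text="hello").post_init()  # ok
try:
    ExampleReq(text="hi", input_ids=[1, 2]).post_init()
except ValueError as e:
    print(e)  # Either text or input_ids should be provided.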
python/sglang/srt/managers/router/model_rpc.py (view file @ 8210ec60)

@@ -20,6 +20,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
@@ -110,6 +111,8 @@ class ModelRpcServer:
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
         set_random_seed(server_args.random_seed)
+
+        # Print info
         logger.info(
             f"Rank {self.tp_rank}: "
             f"max_total_num_token={self.max_total_num_token}, "
@@ -160,24 +163,6 @@ class ModelRpcServer:
         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)

-    def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
-            self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
-            self.req_to_token_pool.clear()
-            self.token_to_kv_pool.clear()
-            torch.cuda.empty_cache()
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
-
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)
@@ -189,6 +174,8 @@ class ModelRpcServer:
                 self.handle_generate_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
+            elif isinstance(recv_req, AbortReq):
+                self.abort_request(recv_req)
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
@@ -207,9 +194,8 @@ class ModelRpcServer:
         new_batch = self.get_new_fill_batch()

         if new_batch is not None:
-            # Run new fill batch
+            # Run a new fill batch
             self.forward_fill_batch(new_batch)
             self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
@@ -225,14 +211,8 @@ class ModelRpcServer:
                 self.num_generated_tokens += len(self.running_batch.reqs)
                 self.forward_decode_batch(self.running_batch)

-                if self.running_batch.is_empty():
-                    self.running_batch = None
-                    break
-
-                if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                    break
-
-                if self.running_batch is not None and self.tp_rank == 0:
+                # Print stats
+                if self.tp_rank == 0:
                     if self.decode_forward_ct % 40 == 0:
                         num_used = self.max_total_num_token - (
                             self.token_to_kv_pool.available_size()
@@ -250,8 +230,15 @@ class ModelRpcServer:
                             f"gen throughput (token/s): {throuhgput:.2f}, "
                             f"#queue-req: {len(self.forward_queue)}"
                         )

+                if self.running_batch.is_empty():
+                    self.running_batch = None
+                    break
+
+                if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                    break
             else:
-                # check the available size
+                # Check the available size
                 available_size = (
                     self.token_to_kv_pool.available_size()
                     + self.tree_cache.evictable_size()
@@ -295,7 +282,7 @@ class ModelRpcServer:
                 req.sampling_params.regex
             )

-        # Truncate long prompts
+        # Truncate prompts that are too long
         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
             req.sampling_params.max_new_tokens,
@@ -311,6 +298,7 @@ class ModelRpcServer:
         ):
             return None

+        # Compute matched prefix length
         for req in self.forward_queue:
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
@@ -383,6 +371,7 @@ class ModelRpcServer:
         if len(can_run_list) == 0:
             return None

+        # Print stats
         if self.tp_rank == 0:
             running_req = (
                 0 if self.running_batch is None else len(self.running_batch.reqs)
@@ -410,6 +399,7 @@ class ModelRpcServer:
             # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
             # )

+        # Return the new batch
         new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,
@@ -487,7 +477,7 @@ class ModelRpcServer:
         self.handle_finished_requests(batch)

     def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],
@@ -671,6 +661,34 @@ class ModelRpcServer:
         else:
             batch.reqs = []

+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
+    def abort_request(self, recv_req):
+        to_del = None
+        for i, req in enumerate(self.forward_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.forward_queue[to_del]
+

 class ModelRpcService(rpyc.Service):
     exposed_ModelRpcServer = ModelRpcServer
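Note: together with the AbortReq dataclass added in io_struct.py, the router can now drop a queued request when the client goes away: the tokenizer manager sends AbortReq(rid), exposed_step dispatches it, and abort_request removes the matching entry from forward_queue. A self-contained sketch of that queue-scan pattern (names below are illustrative, not the sglang internals):

from dataclasses import dataclass


@dataclass
class PendingReq:
    rid: str


forward_queue = [PendingReq("r1"), PendingReq("r2"), PendingReq("r3")]


def abort_request(rid: str) -> None:
    # Only requests still waiting in the queue are dropped; in this version of
    # the commit, a request that already entered the running batch keeps going.
    to_del = None
    for i, req in enumerate(forward_queue):
        if req.rid == rid:
            to_del = i
            break
    if to_del is not None:
        del forward_queue[to_del]


abort_request("r2")
print([r.rid for r in forward_queue])  # ['r1', 'r3']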
python/sglang/srt/managers/tokenizer_manager.py (view file @ 8210ec60)

@@ -19,6 +19,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
     FlushCacheReq,
     GenerateReqInput,
@@ -42,52 +43,6 @@ class ReqState:
     event: asyncio.Event

-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size != None:
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image,
-                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
-                )
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            elif image_aspect_ratio == "anyres":
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            pixel_values = pixel_values.astype(np.float16)
-            return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-
 class TokenizerManager:
     def __init__(
         self,
@@ -154,10 +109,11 @@ class TokenizerManager:
             image_data, aspect_ratio, grid_pinpoints, self.processor
         )

-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-            await self.create_handle_loop()
+            self.create_handle_loop()

+        obj.post_init()
         is_single = obj.is_single

         if is_single:
             rid = obj.rid
@@ -170,7 +126,7 @@ class TokenizerManager:
             if len(input_ids) >= self.context_len:
                 raise ValueError(
                     f"The input ({len(input_ids)} tokens) is longer than the "
-                    f"model's context length ({self.context_len} tokens)"
+                    f"model's context length ({self.context_len} tokens)."
                 )

             sampling_params = SamplingParams(**obj.sampling_params)
@@ -208,7 +164,14 @@ class TokenizerManager:
             self.rid_to_state[rid] = state

             while True:
-                await event.wait()
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=5)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
                 out = self.convert_logprob_style(
                     state.out_list[-1],
                     obj.return_logprob,
@@ -226,7 +189,8 @@ class TokenizerManager:
                     break
                 event.clear()
         else:
-            assert obj.stream is False
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")

             if obj.input_ids is None:
                 bs = len(obj.text)
@@ -276,7 +240,18 @@ class TokenizerManager:
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-                await state.event.wait()
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=5)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
                 output_list.append(
                     self.convert_logprob_style(
                         state.out_list[-1],
@@ -290,11 +265,16 @@ class TokenizerManager:

             yield output_list

-    async def flush_cache(self):
-        flush_cache_req = FlushCacheReq()
-        self.send_to_router.send_pyobj(flush_cache_req)
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)

-    async def create_handle_loop(self):
+    def abort_request(self, rid):
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
@@ -305,17 +285,20 @@ class TokenizerManager:
             if isinstance(recv_obj, BatchStrOut):
                 for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+
                     recv_obj.meta_info[i]["id"] = rid
                     out_dict = {
                         "text": recv_obj.output_str[i],
                         "meta_info": recv_obj.meta_info[i],
                     }
-                    state = self.rid_to_state[rid]
                     state.out_list.append(out_dict)
                     state.finished = recv_obj.finished[i]
                     state.event.set()
             else:
-                raise ValueError(f"Invalid object: {recv_obj}")
+                raise ValueError(f"Invalid object: {recv_obj}.")

     def convert_logprob_style(
         self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
@@ -356,3 +339,50 @@ class TokenizerManager:
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
         return top_logprobs
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())
\ No newline at end of file
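Note: the key change in generate_request is that it no longer blocks indefinitely on event.wait(); it polls with a 5-second timeout and, if the HTTP client has disconnected in the meantime, aborts the request and raises. A self-contained sketch of that pattern (the is_disconnected/abort callables stand in for starlette's Request.is_disconnected() and TokenizerManager.abort_request; they are not sglang APIs themselves):

import asyncio


async def wait_or_abort(event: asyncio.Event, is_disconnected, abort, rid: str):
    while True:
        try:
            # Wake up periodically instead of waiting forever on the event.
            await asyncio.wait_for(event.wait(), timeout=5)
            return
        except asyncio.TimeoutError:
            # No output yet; check whether the client is still connected.
            if await is_disconnected():
                abort(rid)
                raise ValueError(f"Abort request {rid}")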
python/sglang/srt/openai_api_adapter.py (view file @ 8210ec60)

@@ -335,7 +335,7 @@ def to_openai_style_logprobs(
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(logprob)

-        # Not Supported yet
+        # Not supported yet
         ret_logprobs.text_offset.append(-1)

     def append_top_logprobs(top_logprobs):
python/sglang/srt/server.py (view file @ 8210ec60)

@@ -10,6 +10,7 @@ import sys
 import threading
 import time
 from typing import List, Optional, Union
+from http import HTTPStatus

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -73,7 +74,7 @@ async def get_server_args():

 @app.get("/flush_cache")
 async def flush_cache():
-    await tokenizer_manager.flush_cache()
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -81,24 +82,25 @@ async def flush_cache():
     )


-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:

         async def stream_results():
-            async for out in tokenizer_manager.generate_request(obj):
-                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"

         return StreamingResponse(stream_results(), media_type="text/event-stream")
-
-    try:
-        ret = await tokenizer_manager.generate_request(obj).__anext__()
-        return ret
-    except ValueError as e:
-        print(f"Error: {e}")
-        return JSONResponse({"error": str(e)}, status_code=400)
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse(
+                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+            )


 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
@@ -186,6 +188,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

+    # Send a warmup request
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
@@ -228,6 +231,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     t = threading.Thread(target=_wait_and_warmup)
     t.start()

+    # Listen for requests
     try:
         uvicorn.run(
             app,
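Note: /generate now takes the starlette Request so the tokenizer manager can detect client disconnects, and ValueErrors raised downstream (bad input, aborted requests) are returned as structured JSON errors instead of being printed. A minimal FastAPI sketch of the same shape (the handle coroutine is a stand-in for tokenizer_manager.generate_request(...); it is not part of this commit):

from http import HTTPStatus

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()


async def handle(obj: dict, request: Request) -> dict:
    # Stand-in for tokenizer_manager.generate_request(obj, request).__anext__()
    if ("text" in obj) == ("input_ids" in obj):
        raise ValueError("Either text or input_ids should be provided.")
    return {"text": "..."}


@app.post("/generate")
async def generate(obj: dict, request: Request):
    try:
        return await handle(obj, request)
    except ValueError as e:
        # Map downstream errors to a 400 with a structured message.
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )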
python/sglang/test/test_utils.py (view file @ 8210ec60)

@@ -9,7 +9,7 @@ import requests
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
-from sglang.srt.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback


 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
python/sglang/utils.py (view file @ 8210ec60)

@@ -93,8 +93,12 @@ def http_request(
         data = None
     else:
         data = bytes(dumps(json), encoding="utf-8")

-    resp = urllib.request.urlopen(req, data=data, cafile=verify)
-    return HttpResponse(resp)
+    try:
+        resp = urllib.request.urlopen(req, data=data, cafile=verify)
+        return HttpResponse(resp)
+    except urllib.error.HTTPError as e:
+        return HttpResponse(e)


 def encode_image_base64(image_path):
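Note: http_request previously let urllib raise on any non-2xx status; now the HTTPError is caught and wrapped like a normal response, so callers such as _assert_success can read the status code and JSON body. A self-contained sketch of why this works (an HTTPError is itself a file-like response object):

import urllib.error
import urllib.request


def fetch(url: str):
    try:
        return urllib.request.urlopen(url)
    except urllib.error.HTTPError as e:
        # HTTPError carries .code, .headers and .read(), so the caller can
        # inspect it exactly like a successful response object.
        return e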