sglang · commit 6f560c76 (unverified)
Authored Jan 29, 2024 by Lianmin Zheng; committed via GitHub on Jan 29, 2024

Improve the control of streaming and improve the first token latency in streaming (#117)
Parent commit: cd687233
Showing 12 changed files with 46 additions and 23 deletions:

  python/sglang/srt/managers/router/infer_batch.py    +7  -4
  python/sglang/srt/managers/router/manager.py        +1  -1
  python/sglang/srt/managers/router/model_rpc.py      +17 -9
  python/sglang/srt/managers/router/model_runner.py   +2  -1
  python/sglang/srt/managers/tokenizer_manager.py     +1  -1
  python/sglang/srt/models/llava.py                   +1  -1
  python/sglang/srt/server_args.py                    +2  -2
  test/srt/model/test_llama_extend.py                 +1  -1
  test/srt/model/test_llava_low_api.py                +1  -0
  test/srt/test_httpserver_decode.py                  +3  -0
  test/srt/test_httpserver_decode_stream.py           +2  -0
  test/srt/test_httpserver_llava.py                   +8  -3
python/sglang/srt/managers/router/infer_batch.py

```diff
@@ -21,14 +21,17 @@ class FinishReason(Enum):
 class Req:
-    def __init__(self, rid):
+    def __init__(self, rid, input_text, input_ids):
         self.rid = rid
-        self.input_text = None
-        self.input_ids = []
+        self.input_text = input_text
+        self.input_ids = input_ids
         self.output_ids = []
+
+        # For vision input
         self.pixel_values = None
         self.image_size = None
         self.image_offset = 0
+
         self.sampling_params = None
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -46,7 +49,7 @@ class Req:
         self.logprob = None
         self.normalized_logprob = None

-        # for constrained decoding
+        # For constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.fast_forward_map = None
```
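With the new signature, every `Req` is created with its text and token ids in one step; the old pattern of mutating `input_text`/`input_ids` after construction goes away. A minimal sketch of the two call patterns this commit leaves in the tree (the token ids below are made-up placeholders, not real tokenizer output; running it assumes this commit's sglang tree is importable):

```python
from sglang.srt.managers.router.infer_batch import Req

# Server path (see model_rpc.py below): text and ids arrive together.
req = Req("req-0", "The capital of France is", [450, 7483, 310])
assert req.output_ids == []  # populated during decoding

# Test path (see test_llama_extend.py below): pass placeholders and
# tokenize afterwards.
req2 = Req(1, None, None)
req2.input_ids = [1, 2, 3]  # e.g. tokenizer.encode(...)[:cut_num]
```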
python/sglang/srt/managers/router/manager.py

```diff
@@ -40,7 +40,7 @@ class RouterManager:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-            # async sleep for recving the subsequent request, and avoiding cache miss
+            # async sleep for receiving the subsequent request and avoiding cache miss
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
```
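The corrected comment describes the router's event-loop pattern: after pushing outputs to the detokenizer, the coroutine yields so that a request that just arrived can be admitted before the next inference step, keeping its prefix in cache. The RouterManager internals are not shown in this hunk, so the following is only an illustrative asyncio sketch of that yield-after-send pattern (function and parameter names are assumptions):

```python
import asyncio

async def route_loop(step, send_to_detokenizer, yield_s=0.001):
    """Illustrative only: run inference steps and yield after sending
    outputs so newly received requests can join the next batch."""
    while True:
        out_pyobjs = await step()  # one router inference step (assumed API)
        for obj in out_pyobjs:
            send_to_detokenizer(obj)

        # async sleep for receiving the subsequent request and avoiding
        # cache miss (the comment fixed in this hunk)
        if len(out_pyobjs) != 0:
            await asyncio.sleep(yield_s)
```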
python/sglang/srt/managers/router/model_rpc.py

```diff
@@ -17,8 +17,8 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
-    TokenizedGenerateReqInput,
     FlushCacheReq,
+    TokenizedGenerateReqInput,
 )
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
@@ -194,6 +194,9 @@ class ModelRpcServer(rpyc.Service):
                     if self.running_batch.is_empty():
                         self.running_batch = None
                         break
+
+                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                        break
             else:
                 # check the available size
                 available_size = (
@@ -208,8 +211,7 @@ class ModelRpcServer(rpyc.Service):
                 )
                 if self.running_batch is not None and self.tp_rank == 0:
-                    if self.decode_forward_ct >= 20:
-                        self.decode_forward_ct = 0
+                    if self.decode_forward_ct % 20 == 0:
                         num_used = self.max_total_num_token - (
                             self.token_to_kv_pool.available_size()
                             + self.tree_cache.evictable_size()
@@ -225,11 +227,8 @@ class ModelRpcServer(rpyc.Service):
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        req = Req(recv_req.rid)
-        req.input_text = recv_req.input_text
-        req.input_ids = recv_req.input_ids
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.pixel_values = recv_req.pixel_values
-        req.image_size = recv_req.image_size
         if req.pixel_values is not None:
             pad_value = [
                 (recv_req.image_hash) % self.model_config.vocab_size,
@@ -240,6 +239,7 @@ class ModelRpcServer(rpyc.Service):
             req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                 req.input_ids, pad_value, req.pixel_values.shape, req.image_size
             )
+            req.image_size = recv_req.image_size
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
@@ -327,9 +327,11 @@ class ModelRpcServer(rpyc.Service):
                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
+                    # Undo the insertion
                     delta = self.tree_cache.dec_ref_counter(req.last_node)
                     available_size += delta
                 else:
+                    # Add this request to the running batch
                     self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
@@ -421,7 +423,7 @@ class ModelRpcServer(rpyc.Service):
             return

         # Update batch tensors
-        self.decode_forward_ct += 1
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()

         # Forward
@@ -454,7 +456,13 @@ class ModelRpcServer(rpyc.Service):
                 unfinished_indices.append(i)

             if req.finished or (
-                req.stream and self.decode_forward_ct % self.stream_interval == 0
+                (
+                    req.stream
+                    and (
+                        self.decode_forward_ct % self.stream_interval == 0
+                        or len(req.output_ids) == 1
+                    )
+                )
             ):
                 output_rids.append(req.rid)
                 output_tokens.append(req.output_ids)
```
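Taken together, the model_rpc.py changes are the heart of the commit: `decode_forward_ct` becomes a wrapping counter so the `%` tests stay well defined indefinitely, the decode loop breaks out early once a streaming request already has output queued, and a streamed request now flushes either on the usual interval or immediately after its very first output token. That last clause is what improves first-token latency. A small self-contained sketch of the flush decision (the constant and function names are mine; the condition mirrors the diff):

```python
STREAM_INTERVAL = 8  # new default from server_args.py below

def should_flush(decode_forward_ct: int, num_output_tokens: int, stream: bool) -> bool:
    # Mirrors the updated condition in model_rpc.py: flush on the interval,
    # or as soon as the first token exists.
    return stream and (
        decode_forward_ct % STREAM_INTERVAL == 0 or num_output_tokens == 1
    )

ct = 0
for step in range(1, 10):
    ct = (ct + 1) % (1 << 30)  # the new wrapping counter update
    if should_flush(ct, num_output_tokens=step, stream=True):
        print("flush at decode step", step)
# Prints step 1 (first token, immediately) and step 8 (interval).
# Under the old interval-only condition, the first flush could wait
# a full interval of decode steps.
```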
python/sglang/srt/managers/router/model_runner.py

```diff
@@ -7,7 +7,6 @@ from typing import List
 import numpy as np
 import torch
-import sglang
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
@@ -16,6 +15,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

+import sglang
+
 logger = logging.getLogger("model_runner")
```
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -18,9 +18,9 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
-    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
```
python/sglang/srt/models/llava.py

```diff
@@ -158,7 +158,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                     num_patch_height, num_patch_width, height, width, -1
                 )
             else:
-                raise NotImplementedError
+                raise NotImplementedError()
             if "unpad" in self.mm_patch_merge_type:
                 image_feature = image_feature.permute(
                     4, 0, 2, 1, 3
```
python/sglang/srt/server_args.py

```diff
@@ -19,7 +19,7 @@ class ServerArgs:
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
     random_seed: int = 42
-    stream_interval: int = 2
+    stream_interval: int = 8
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
@@ -132,7 +132,7 @@ class ServerArgs:
         "--stream-interval",
         type=int,
         default=ServerArgs.stream_interval,
-        help="The interval in terms of token length for streaming",
+        help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
     )
     parser.add_argument(
         "--log-level",
```
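The default interval moves from 2 to 8 decode steps, which buffers more tokens per streamed chunk and raises throughput; since the first token now bypasses the interval entirely (see model_rpc.py above), the larger default should not hurt time to first token. Users who prefer smoother output can still lower it at launch, for example reusing the model from the tests below:

```
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 --stream-interval 2
```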
test/srt/model/test_llama_extend.py

```diff
@@ -28,7 +28,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
     reqs = []
     for i in range(len(prompts)):
-        req = Req(i)
+        req = Req(i, None, None)
         req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
         req.sampling_params = sampling_params
         reqs.append(req)
```
test/srt/model/test_llava_low_api.py

```diff
@@ -112,6 +112,7 @@ def test_generate_worker(
     prefill_params = (
         torch.tensor(np.array(input_ids)).cuda(),
         np.array(pixel_values),
+        [None],
         [offset],
         *params,
     )
```
test/srt/test_httpserver_decode.py

```diff
 """
+Usage:
 python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
+python3 test_httpserver_decode.py
+
 Output:
 The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
 """
```
test/srt/test_httpserver_decode_stream.py

```diff
 """
+Usage:
 python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
+python3 test_httpserver_decode_stream.py
 Output:
 The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
 """
```
test/srt/test_httpserver_llava.py

```diff
 """
+Usage:
 python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
+python3 test_httpserver_llava.py
 Output:
 The image features a man standing on the back of a yellow taxi cab, holding
 """
@@ -64,9 +66,12 @@ def test_streaming(args):
     )
     prev = 0
-    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            data = json.loads(chunk.decode())
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
             output = data["text"].strip()
             print(output[prev:], end="", flush=True)
             prev = len(output)
```
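The rewritten client loop implies the server now streams standard server-sent events: each chunk is a `data:` line carrying a JSON payload with a `text` field, terminated by `data: [DONE]`, instead of NUL-delimited JSON blobs. For reference, a response body the updated parser would accept looks roughly like this (payloads abbreviated; any fields besides `text` are not shown in this diff):

```
data: {"text": "The image features"}

data: {"text": "The image features a man standing"}

data: [DONE]
```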