sglang commit 2b809788 (unverified)
Authored Oct 26, 2024 by Lianmin Zheng; committed by GitHub on Oct 26, 2024

Provide an argument to set the maximum batch size for cuda graph (#1809)

Parent: 9d6fb084
Showing 4 changed files with 25 additions and 10 deletions:

  python/sglang/srt/managers/schedule_policy.py          +11  -6
  python/sglang/srt/model_executor/cuda_graph_runner.py   +6  -3
  python/sglang/srt/server_args.py                         +7  -0
  test/srt/test_large_max_new_tokens.py                    +1  -1
python/sglang/srt/managers/schedule_policy.py

@@ -30,7 +30,9 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
+CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
+)


 class SchedulePolicy:

@@ -146,7 +148,7 @@ class PrefillAdder:
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
                 )
                 * self.new_token_ratio
                 for r in running_batch.reqs

@@ -186,7 +188,7 @@ class PrefillAdder:
             len(req.prefix_indices),
             req.extend_input_len,
             (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                 if not truncated
                 else 0
             ),

@@ -258,7 +260,7 @@ class PrefillAdder:
             self._prefill_one_req(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
             # Chunked prefill

@@ -276,7 +278,7 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req)

         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
         input_tokens = req.extend_input_len
         prefix_len = len(req.prefix_indices)

@@ -302,7 +304,10 @@ class PrefillAdder:
             self._prefill_one_req(
                 prefix_len,
                 input_tokens,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(
+                    req.sampling_params.max_new_tokens,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                ),
             )
         else:
             # Chunked prefill
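Note on the rename above: CLIP_MAX_NEW_TOKENS_ESTIMATION only caps the scheduler's
token-budget estimate; a request still generates until its real max_new_tokens. A minimal
stand-alone sketch of that estimate (the helper name and arguments are illustrative,
mirroring r.sampling_params.max_new_tokens, r.output_ids, and self.new_token_ratio from
the hunks above):

    import os

    # Same default as in the diff; the env var only affects the scheduler's estimate.
    CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
        os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
    )

    def estimated_remaining_tokens(
        max_new_tokens: int, num_output_tokens: int, new_token_ratio: float
    ) -> float:
        """Hypothetical helper: the clipped per-request estimate used for scheduling."""
        remaining = max_new_tokens - num_output_tokens
        # Clipping keeps one request with a huge max_new_tokens from making the
        # scheduler overly conservative; it does not change the stop condition.
        return min(remaining, CLIP_MAX_NEW_TOKENS_ESTIMATION) * new_token_ratio

    # A request asking for 100k more tokens is budgeted as only 4096 * ratio.
    print(estimated_remaining_tokens(100_000, 50, 0.5))  # 2048.0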
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -113,12 +113,15 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder

         # Batch sizes to capture
-        if self.model_runner.server_args.disable_cuda_graph_padding:
+        if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
-            self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         self.capture_bs = [
-            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+            bs
+            for bs in self.capture_bs
+            if bs <= model_runner.req_to_token_pool.size
+            and bs <= model_runner.server_args.max_cuda_graph_bs
         ]
         self.compile_bs = (
             [
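The effect of the new max_cuda_graph_bs filter is easier to see in isolation. A minimal
sketch of the selection logic above, with req_to_token_pool_size and max_cuda_graph_bs as
plain arguments standing in for the runner and server-arg fields:

    def candidate_capture_bs(disable_cuda_graph_padding: bool) -> list[int]:
        # Candidate batch sizes, as in the diff above.
        if disable_cuda_graph_padding:
            return list(range(1, 32)) + [64, 128]
        return [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 8, 16, ..., 160

    def filtered_capture_bs(
        disable_cuda_graph_padding: bool,
        req_to_token_pool_size: int,
        max_cuda_graph_bs: int,
    ) -> list[int]:
        """Keep only batch sizes that fit both the request pool and the new cap."""
        return [
            bs
            for bs in candidate_capture_bs(disable_cuda_graph_padding)
            if bs <= req_to_token_pool_size and bs <= max_cuda_graph_bs
        ]

    # With the default max_cuda_graph_bs=160 nothing is dropped; a smaller cap trims
    # the list, which reduces cuda graph capture time and memory.
    print(filtered_capture_bs(False, 4096, 64))
    # [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]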
python/sglang/srt/server_args.py

@@ -120,6 +120,7 @@ class ServerArgs:
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
+    max_cuda_graph_bs: int = 160
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False

@@ -624,6 +625,12 @@ class ServerArgs:
             default=ServerArgs.max_torch_compile_bs,
             help="Set the maximum batch size when using torch compile.",
         )
+        parser.add_argument(
+            "--max-cuda-graph-bs",
+            type=int,
+            default=ServerArgs.max_cuda_graph_bs,
+            help="Set the maximum batch size for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
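With these two changes the cap becomes configurable at launch time. A minimal stand-alone
sketch of how the new flag behaves when parsed (it mirrors the add_argument call above;
the real flag is registered on sglang's ServerArgs parser, and the value 64 is
illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--max-cuda-graph-bs",
        type=int,
        default=160,  # same default as ServerArgs.max_cuda_graph_bs
        help="Set the maximum batch size for cuda graph.",
    )
    args = parser.parse_args(["--max-cuda-graph-bs", "64"])
    print(args.max_cuda_graph_bs)  # 64

On the command line this corresponds to passing, e.g., --max-cuda-graph-bs 64 when
launching the server.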
test/srt/test_large_max_new_tokens.py

@@ -34,7 +34,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
             other_args=("--max-total-token", "1024", "--context-len", "8192"),
-            env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
+            env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
             return_stdout_stderr=(cls.stdout, cls.stderr),
         )
         cls.base_url += "/v1"
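The test change above follows the environment-variable rename: after this commit the
scheduler reads SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION and no longer reads
SGLANG_CLIP_MAX_NEW_TOKENS. A one-line sketch of updating an existing setup (the value
256 is the one used by this test; any limit works):

    import os

    # Old name, no longer read:  os.environ["SGLANG_CLIP_MAX_NEW_TOKENS"] = "256"
    os.environ["SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION"] = "256"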