sglang · Commits · 2b809788

Unverified commit 2b809788, authored Oct 26, 2024 by Lianmin Zheng, committed by GitHub on Oct 26, 2024

Provide an argument to set the maximum batch size for cuda graph (#1809)
Parent: 9d6fb084
Showing 4 changed files with 25 additions and 10 deletions (+25, -10)
python/sglang/srt/managers/schedule_policy.py (+11, -6)
python/sglang/srt/model_executor/cuda_graph_runner.py (+6, -3)
python/sglang/srt/server_args.py (+7, -0)
test/srt/test_large_max_new_tokens.py (+1, -1)
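The user-facing change is a new launch flag. A minimal usage sketch, assuming sglang is installed and started through its usual python -m sglang.launch_server entry point; the model path is a placeholder and 64 is an arbitrary cap, neither comes from the diff:

import subprocess

# Launch the server with the new cap on CUDA graph capture batch size (added by this commit).
cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "PLACEHOLDER/MODEL",  # placeholder model path
    "--max-cuda-graph-bs", "64",          # new argument; the default is 160
]
subprocess.run(cmd, check=True)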
python/sglang/srt/managers/schedule_policy.py

@@ -30,7 +30,9 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
+CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
+)


 class SchedulePolicy:
@@ -146,7 +148,7 @@ class PrefillAdder:
                 [
                     min(
                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                        CLIP_MAX_NEW_TOKENS,
+                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                     )
                     * self.new_token_ratio
                     for r in running_batch.reqs
@@ -186,7 +188,7 @@ class PrefillAdder:
             len(req.prefix_indices),
             req.extend_input_len,
             (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                 if not truncated
                 else 0
             ),
@@ -258,7 +260,7 @@ class PrefillAdder:
             self._prefill_one_req(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
             # Chunked prefill
@@ -276,7 +278,7 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req)

         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
         input_tokens = req.extend_input_len
         prefix_len = len(req.prefix_indices)
@@ -302,7 +304,10 @@ class PrefillAdder:
             self._prefill_one_req(
                 prefix_len,
                 input_tokens,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(
+                    req.sampling_params.max_new_tokens,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                ),
             )
         else:
             # Chunked prefill
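The rename only touches the scheduler's budget estimate: SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION caps how many future decode tokens a running request is assumed to still need, without changing the real stop condition. A small self-contained sketch of that estimate; the env var, the min() clip, and new_token_ratio mirror the diff, while the helper function itself is illustrative and not an sglang API:

import os

# Mirrors the clipping in the diff: cap the per-request estimate, not the actual generation limit.
CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)

def estimated_remaining_tokens(
    max_new_tokens: int, num_output_ids: int, new_token_ratio: float
) -> float:
    """Illustrative helper: estimated decode budget reserved for one running request."""
    remaining = max_new_tokens - num_output_ids
    return min(remaining, CLIP_MAX_NEW_TOKENS_ESTIMATION) * new_token_ratio

# A request allowed 100k new tokens only reserves about 4096 * ratio in the scheduler's estimate.
print(estimated_remaining_tokens(max_new_tokens=100_000, num_output_ids=12, new_token_ratio=0.5))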
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -113,12 +113,15 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder

         # Batch sizes to capture
-        if self.model_runner.server_args.disable_cuda_graph_padding:
+        if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
-            self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         self.capture_bs = [
-            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+            bs
+            for bs in self.capture_bs
+            if bs <= model_runner.req_to_token_pool.size
+            and bs <= model_runner.server_args.max_cuda_graph_bs
         ]
         self.compile_bs = (
             [
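To make the new cap concrete, here is a small self-contained sketch of how the capture list ends up bounded. Only max_cuda_graph_bs (default 160) and the default batch-size lists come from the diff; the function name and the example pool size are illustrative stand-ins:

def capped_capture_batch_sizes(
    max_cuda_graph_bs: int = 160,        # new server argument (default from the diff)
    req_to_token_pool_size: int = 4096,  # illustrative pool size, not from the diff
    disable_cuda_graph_padding: bool = False,
) -> list[int]:
    """Illustrative stand-in for CudaGraphRunner's capture_bs selection."""
    if disable_cuda_graph_padding:
        capture_bs = list(range(1, 32)) + [64, 128]
    else:
        capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
    # Keep only batch sizes that fit the request pool and the user-set cap.
    return [
        bs
        for bs in capture_bs
        if bs <= req_to_token_pool_size and bs <= max_cuda_graph_bs
    ]

# With --max-cuda-graph-bs 64, graphs above batch size 64 are simply not captured.
print(capped_capture_batch_sizes(max_cuda_graph_bs=64))  # [1, 2, 4, 8, ..., 64]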
python/sglang/srt/server_args.py

@@ -120,6 +120,7 @@ class ServerArgs:
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
+    max_cuda_graph_bs: int = 160
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
@@ -624,6 +625,12 @@ class ServerArgs:
             default=ServerArgs.max_torch_compile_bs,
             help="Set the maximum batch size when using torch compile.",
         )
+        parser.add_argument(
+            "--max-cuda-graph-bs",
+            type=int,
+            default=ServerArgs.max_cuda_graph_bs,
+            help="Set the maximum batch size for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
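The flag follows the existing pattern in which the dataclass field supplies the argparse default. A minimal sketch of that wiring in isolation; only the field name, the flag name, the help text, and the default of 160 mirror the diff, the rest is illustrative:

import argparse
from dataclasses import dataclass

@dataclass
class ServerArgsSketch:
    # Illustrative stand-in for sglang's ServerArgs; only this field mirrors the diff.
    max_cuda_graph_bs: int = 160

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max-cuda-graph-bs",
    type=int,
    default=ServerArgsSketch.max_cuda_graph_bs,  # dataclass default feeds the CLI default
    help="Set the maximum batch size for cuda graph.",
)

args = parser.parse_args(["--max-cuda-graph-bs", "64"])
print(args.max_cuda_graph_bs)  # 64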
test/srt/test_large_max_new_tokens.py

@@ -34,7 +34,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
             other_args=("--max-total-token", "1024", "--context-len", "8192"),
-            env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
+            env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
             return_stdout_stderr=(cls.stdout, cls.stderr),
         )
         cls.base_url += "/v1"
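The test only follows the env-var rename; the override pattern itself is "add the variable on top of the inherited environment". A self-contained sketch of that pattern, where the child command is a placeholder for the real server launch:

import os
import subprocess

# Mirror the test's env= dict: set the override, then inherit the rest of the environment.
env = {"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ}

# Placeholder child process; the real test launches the sglang server instead.
subprocess.run(
    ["python", "-c", "import os; print(os.environ['SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION'])"],
    env=env,
    check=True,
)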