Unverified commit 2b809788 authored by Lianmin Zheng, committed by GitHub

Provide an argument to set the maximum batch size for cuda graph (#1809)

parent 9d6fb084
@@ -30,7 +30,9 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
+CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
+)
 
 
 class SchedulePolicy:
@@ -146,7 +148,7 @@ class PrefillAdder:
                 [
                     min(
                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                        CLIP_MAX_NEW_TOKENS,
+                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                     )
                     * self.new_token_ratio
                     for r in running_batch.reqs
@@ -186,7 +188,7 @@ class PrefillAdder:
             len(req.prefix_indices),
             req.extend_input_len,
             (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                 if not truncated
                 else 0
             ),
@@ -258,7 +260,7 @@ class PrefillAdder:
             self._prefill_one_req(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
             # Chunked prefill
@@ -276,7 +278,7 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req)
 
         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
         input_tokens = req.extend_input_len
         prefix_len = len(req.prefix_indices)
@@ -302,7 +304,10 @@ class PrefillAdder:
             self._prefill_one_req(
                 prefix_len,
                 input_tokens,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(
+                    req.sampling_params.max_new_tokens,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                ),
             )
         else:
             # Chunked prefill
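The rename above only affects the scheduler's memory-budget estimate; the stop condition is untouched, so a request can still generate tokens up to its unclipped max_new_tokens. A minimal sketch of how the clipped estimate behaves, using the constant exactly as defined in this diff (the helper name estimate_remaining_tokens is hypothetical, for illustration only):

import os

# Same constant as in the diff: an env-tunable clip on the scheduler's estimate only.
CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)


def estimate_remaining_tokens(max_new_tokens: int, num_output_ids: int) -> int:
    # Tokens the request may still produce, clipped for budgeting so a single
    # request with a huge max_new_tokens does not dominate the estimate.
    return min(max_new_tokens - num_output_ids, CLIP_MAX_NEW_TOKENS_ESTIMATION)


# A request that asked for 100_000 new tokens and has emitted 500 so far is
# budgeted as 4096 tokens, even though it may ultimately generate 99_500.
print(estimate_remaining_tokens(100_000, 500))  # -> 4096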
@@ -113,12 +113,15 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
 
         # Batch sizes to capture
-        if self.model_runner.server_args.disable_cuda_graph_padding:
+        if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
-            self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         self.capture_bs = [
-            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+            bs
+            for bs in self.capture_bs
+            if bs <= model_runner.req_to_token_pool.size
+            and bs <= model_runner.server_args.max_cuda_graph_bs
         ]
         self.compile_bs = (
             [
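The new max_cuda_graph_bs setting is applied as one extra condition in the filter over candidate capture sizes shown above. Below is a standalone sketch of that selection logic under the same defaults, with pool_size standing in for model_runner.req_to_token_pool.size (the function name capture_batch_sizes is hypothetical):

def capture_batch_sizes(
    disable_cuda_graph_padding: bool, pool_size: int, max_cuda_graph_bs: int
) -> list:
    # Candidate batch sizes, mirroring CudaGraphRunner in this diff.
    if disable_cuda_graph_padding:
        capture_bs = list(range(1, 32)) + [64, 128]
    else:
        capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, 8, ..., 160
    # Keep only sizes that fit the request pool and the user-set cap.
    return [bs for bs in capture_bs if bs <= pool_size and bs <= max_cuda_graph_bs]


# With the default cap of 160 nothing is dropped; a smaller cap such as 64
# trims the large sizes, so fewer CUDA graphs are captured at startup.
print(capture_batch_sizes(False, 4096, 160)[-1])  # -> 160
print(capture_batch_sizes(False, 4096, 64)[-1])   # -> 64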
@@ -120,6 +120,7 @@ class ServerArgs:
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
+    max_cuda_graph_bs: int = 160
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
@@ -624,6 +625,12 @@ class ServerArgs:
             default=ServerArgs.max_torch_compile_bs,
             help="Set the maximum batch size when using torch compile.",
         )
+        parser.add_argument(
+            "--max-cuda-graph-bs",
+            type=int,
+            default=ServerArgs.max_cuda_graph_bs,
+            help="Set the maximum batch size for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
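With this change, the cap can be set on the command line via the new --max-cuda-graph-bs flag or programmatically through ServerArgs. A small sketch, assuming only the dataclass field added in this diff (the model path is illustrative):

from sglang.srt.server_args import ServerArgs

# max_cuda_graph_bs defaults to 160; here it is lowered so CUDA graphs are
# captured only for batch sizes up to 64.
args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct", max_cuda_graph_bs=64)
assert args.max_cuda_graph_bs == 64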
@@ -34,7 +34,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
             other_args=("--max-total-token", "1024", "--context-len", "8192"),
-            env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
+            env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
             return_stdout_stderr=(cls.stdout, cls.stderr),
         )
         cls.base_url += "/v1"