Unverified Commit 2b809788 authored by Lianmin Zheng, committed by GitHub

Provide an argument to set the maximum batch size for cuda graph (#1809)

parent 9d6fb084
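
The new cap can be set at server launch. A minimal sketch, assuming the standard sglang.launch_server entrypoint and an illustrative model path (neither is part of this diff):

import subprocess
import sys

# Capture cuda graphs only up to batch size 64 instead of the default of 160.
# The entrypoint and model path below are assumptions for illustration.
subprocess.run(
    [
        sys.executable, "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
        "--max-cuda-graph-bs", "64",
    ],
    check=True,
)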
@@ -30,7 +30,9 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
+CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
+)
 
 
 class SchedulePolicy:
@@ -146,7 +148,7 @@ class PrefillAdder:
                 [
                     min(
                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                        CLIP_MAX_NEW_TOKENS,
+                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                     )
                     * self.new_token_ratio
                     for r in running_batch.reqs
@@ -186,7 +188,7 @@ class PrefillAdder:
             len(req.prefix_indices),
             req.extend_input_len,
             (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                 if not truncated
                 else 0
             ),
@@ -258,7 +260,7 @@ class PrefillAdder:
             self._prefill_one_req(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
             # Chunked prefill
@@ -276,7 +278,7 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req)
 
         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
         input_tokens = req.extend_input_len
         prefix_len = len(req.prefix_indices)
@@ -302,7 +304,10 @@ class PrefillAdder:
                 self._prefill_one_req(
                     prefix_len,
                     input_tokens,
-                    min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                    min(
+                        req.sampling_params.max_new_tokens,
+                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                    ),
                 )
             else:
                 # Chunked prefill
...
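
For context on the rename: the environment variable only caps the scheduler's estimate of how many tokens a running request may still produce; generation itself still runs to the unclipped max_new_tokens. A standalone sketch of that behavior (illustrative helper, not the actual PrefillAdder code):

import os

# Illustrative default mirrors the diff: SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION, 4096.
CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)

def estimated_remaining_tokens(max_new_tokens: int, generated: int) -> int:
    """Scheduler-side estimate of tokens a request may still emit.

    The clip bounds only this estimate; the request keeps generating until it
    reaches its unclipped max_new_tokens or another stop condition.
    """
    return min(max_new_tokens - generated, CLIP_MAX_NEW_TOKENS_ESTIMATION)

# A request allowed 100_000 new tokens with 1_000 already generated is budgeted
# as 4096 tokens by the scheduler, not 99_000.
print(estimated_remaining_tokens(100_000, 1_000))  # 4096 with the default clip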
@@ -113,12 +113,15 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
 
         # Batch sizes to capture
-        if self.model_runner.server_args.disable_cuda_graph_padding:
+        if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
-            self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         self.capture_bs = [
-            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+            bs
+            for bs in self.capture_bs
+            if bs <= model_runner.req_to_token_pool.size
+            and bs <= model_runner.server_args.max_cuda_graph_bs
         ]
         self.compile_bs = (
             [
...
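
A standalone sketch of the capture-list logic this hunk introduces: candidate batch sizes are built as before and then filtered by both the request-pool size and the new max_cuda_graph_bs cap (a simplified function under assumed argument names, not CudaGraphRunner itself):

def capture_batch_sizes(
    disable_cuda_graph_padding: bool,
    req_pool_size: int,
    max_cuda_graph_bs: int,
) -> list[int]:
    """Return the batch sizes to capture as cuda graphs (illustrative sketch)."""
    if disable_cuda_graph_padding:
        candidates = list(range(1, 32)) + [64, 128]
    else:
        candidates = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 8, 16, ..., 160
    return [
        bs
        for bs in candidates
        if bs <= req_pool_size and bs <= max_cuda_graph_bs
    ]

# Lowering the cap shrinks the list, which reduces capture time and the memory
# spent on cuda graphs.
print(capture_batch_sizes(False, 4096, 64))  # [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]

The default of 160 equals the largest candidate in the default list, so behavior only changes when the flag is lowered.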
@@ -120,6 +120,7 @@ class ServerArgs:
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
+    max_cuda_graph_bs: int = 160
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
@@ -624,6 +625,12 @@ class ServerArgs:
             default=ServerArgs.max_torch_compile_bs,
             help="Set the maximum batch size when using torch compile.",
         )
+        parser.add_argument(
+            "--max-cuda-graph-bs",
+            type=int,
+            default=ServerArgs.max_cuda_graph_bs,
+            help="Set the maximum batch size for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
...
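
The flag follows the existing ServerArgs pattern: the dataclass field supplies the default (160) and argparse exposes it on the command line. A minimal, self-contained sketch of that pattern, with a hypothetical MiniServerArgs standing in for the real class:

import argparse
from dataclasses import dataclass

@dataclass
class MiniServerArgs:  # illustrative stand-in for sglang's ServerArgs
    max_torch_compile_bs: int = 32
    max_cuda_graph_bs: int = 160

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max-cuda-graph-bs",
    type=int,
    default=MiniServerArgs.max_cuda_graph_bs,
    help="Set the maximum batch size for cuda graph.",
)
args = parser.parse_args(["--max-cuda-graph-bs", "64"])
print(args.max_cuda_graph_bs)  # 64; omitting the flag keeps the default of 160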
@@ -34,7 +34,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
             other_args=("--max-total-token", "1024", "--context-len", "8192"),
-            env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
+            env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
             return_stdout_stderr=(cls.stdout, cls.stderr),
         )
         cls.base_url += "/v1"
...
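
Because CLIP_MAX_NEW_TOKENS_ESTIMATION is read once at module import, the test injects the renamed variable into the server process's environment via env= rather than exporting it afterwards. A small illustrative sketch of that pattern (the child command here is a stand-in, not the test's popen_launch_server):

import os
import subprocess
import sys

# The override must be in the child's environment before sglang's scheduler
# module is imported, since the clip value is captured once at import time.
child_env = {"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ}
subprocess.run(
    [
        sys.executable,
        "-c",
        "import os; print(os.environ['SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION'])",
    ],
    env=child_env,
    check=True,
)  # the child sees 256 and would clip its estimation accordingly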