Unverified Commit 86e0dde5 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Improve the user control of new_token_ratio (#1811)

parent 2b809788
...@@ -14,9 +14,15 @@ class GlobalConfig: ...@@ -14,9 +14,15 @@ class GlobalConfig:
self.default_backend = None self.default_backend = None
# Runtime constants: New generation token ratio estimation # Runtime constants: New generation token ratio estimation
self.init_new_token_ratio = 0.7 self.default_init_new_token_ratio = float(
self.base_min_new_token_ratio = 0.1 os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
self.new_token_ratio_decay = 0.001 )
self.default_min_new_token_ratio_factor = float(
os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
)
self.default_new_token_ratio_decay_steps = float(
os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
)
# Runtime constants: others # Runtime constants: others
self.retract_decode_steps = 20 self.retract_decode_steps = 20
......
...@@ -254,13 +254,22 @@ class Scheduler: ...@@ -254,13 +254,22 @@ class Scheduler:
assert ( assert (
server_args.schedule_conservativeness >= 0 server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness" ), "Invalid schedule_conservativeness"
self.min_new_token_ratio = min(
global_config.base_min_new_token_ratio self.init_new_token_ratio = min(
global_config.default_init_new_token_ratio
* server_args.schedule_conservativeness, * server_args.schedule_conservativeness,
1.0, 1.0,
) )
self.new_token_ratio = self.min_new_token_ratio self.min_new_token_ratio = min(
self.new_token_ratio_decay = global_config.new_token_ratio_decay self.init_new_token_ratio
* global_config.default_min_new_token_ratio_factor,
1.0,
)
self.new_token_ratio_decay = (
self.init_new_token_ratio - self.min_new_token_ratio
) / global_config.default_new_token_ratio_decay_steps
self.new_token_ratio = self.init_new_token_ratio
self.batch_is_full = False self.batch_is_full = False
# Init profiler # Init profiler
...@@ -307,7 +316,7 @@ class Scheduler: ...@@ -307,7 +316,7 @@ class Scheduler:
self.process_batch_result(batch, result) self.process_batch_result(batch, result)
else: else:
self.check_memory() self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio self.new_token_ratio = self.init_new_token_ratio
self.last_batch = batch self.last_batch = batch
...@@ -334,7 +343,7 @@ class Scheduler: ...@@ -334,7 +343,7 @@ class Scheduler:
self.process_batch_result(tmp_batch, tmp_result) self.process_batch_result(tmp_batch, tmp_result)
elif batch is None: elif batch is None:
self.check_memory() self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio self.new_token_ratio = self.init_new_token_ratio
self.last_batch = batch self.last_batch = batch
......
...@@ -121,13 +121,13 @@ class CudaGraphRunner: ...@@ -121,13 +121,13 @@ class CudaGraphRunner:
bs bs
for bs in self.capture_bs for bs in self.capture_bs
if bs <= model_runner.req_to_token_pool.size if bs <= model_runner.req_to_token_pool.size
and bs <= model_runner.server_args.max_cuda_graph_bs and bs <= model_runner.server_args.cuda_graph_max_bs
] ]
self.compile_bs = ( self.compile_bs = (
[ [
bs bs
for bs in self.capture_bs for bs in self.capture_bs
if bs <= self.model_runner.server_args.max_torch_compile_bs if bs <= self.model_runner.server_args.torch_compile_max_bs
] ]
if self.use_torch_compile if self.use_torch_compile
else [] else []
......
...@@ -119,8 +119,8 @@ class ServerArgs: ...@@ -119,8 +119,8 @@ class ServerArgs:
enable_overlap_schedule: bool = False enable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False enable_mixed_chunk: bool = False
enable_torch_compile: bool = False enable_torch_compile: bool = False
max_torch_compile_bs: int = 32 torch_compile_max_bs: int = 32
max_cuda_graph_bs: int = 160 cuda_graph_max_bs: int = 160
torchao_config: str = "" torchao_config: str = ""
enable_p2p_check: bool = False enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False triton_attention_reduce_in_fp32: bool = False
...@@ -620,15 +620,15 @@ class ServerArgs: ...@@ -620,15 +620,15 @@ class ServerArgs:
help="Optimize the model with torch.compile. Experimental feature.", help="Optimize the model with torch.compile. Experimental feature.",
) )
parser.add_argument( parser.add_argument(
"--max-torch-compile-bs", "--torch-compile-max-bs",
type=int, type=int,
default=ServerArgs.max_torch_compile_bs, default=ServerArgs.torch_compile_max_bs,
help="Set the maximum batch size when using torch compile.", help="Set the maximum batch size when using torch compile.",
) )
parser.add_argument( parser.add_argument(
"--max-cuda-graph-bs", "--cuda-graph-max-bs",
type=int, type=int,
default=ServerArgs.max_cuda_graph_bs, default=ServerArgs.cuda_graph_max_bs,
help="Set the maximum batch size for cuda graph.", help="Set the maximum batch size for cuda graph.",
) )
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment