Unverified commit 86e0dde5, authored by Lianmin Zheng and committed by GitHub

Improve the user control of new_token_ratio (#1811)

parent 2b809788
@@ -14,9 +14,15 @@ class GlobalConfig:
         self.default_backend = None

         # Runtime constants: New generation token ratio estimation
-        self.init_new_token_ratio = 0.7
-        self.base_min_new_token_ratio = 0.1
-        self.new_token_ratio_decay = 0.001
+        self.default_init_new_token_ratio = float(
+            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
+        )
+        self.default_min_new_token_ratio_factor = float(
+            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
+        )
+        self.default_new_token_ratio_decay_steps = float(
+            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
+        )

         # Runtime constants: others
         self.retract_decode_steps = 20
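The three hardcoded constants are now read from environment variables, so they can be tuned without editing the source. A minimal sketch of overriding them, assuming the usual `global_config` singleton in `sglang.global_config` (the variables must be set before the config module is imported):

```python
import os

# Override the defaults introduced above; set these before any SGLang
# module instantiates GlobalConfig.
os.environ["SGLANG_INIT_NEW_TOKEN_RATIO"] = "0.8"         # start higher
os.environ["SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR"] = "0.2"   # floor = 0.2 * init
os.environ["SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS"] = "300"  # decay twice as fast

from sglang.global_config import global_config

print(global_config.default_init_new_token_ratio)  # 0.8
```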
@@ -254,13 +254,22 @@ class Scheduler:
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-        self.min_new_token_ratio = min(
-            global_config.base_min_new_token_ratio
+        self.init_new_token_ratio = min(
+            global_config.default_init_new_token_ratio
             * server_args.schedule_conservativeness,
             1.0,
         )
-        self.new_token_ratio = self.min_new_token_ratio
-        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.min_new_token_ratio = min(
+            self.init_new_token_ratio
+            * global_config.default_min_new_token_ratio_factor,
+            1.0,
+        )
+        self.new_token_ratio_decay = (
+            self.init_new_token_ratio - self.min_new_token_ratio
+        ) / global_config.default_new_token_ratio_decay_steps
+        self.new_token_ratio = self.init_new_token_ratio
         self.batch_is_full = False

         # Init profiler
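With the defaults (init 0.7, factor 0.14, 600 decay steps) and schedule_conservativeness = 1.0, the derived schedule closely reproduces the values that were hardcoded before: the floor becomes 0.098 (close to the old base_min_new_token_ratio of 0.1) and the per-step decay comes out near the old 0.001. A quick check:

```python
# Worked example of the schedule above, using the default values and
# assuming schedule_conservativeness == 1.0.
init_new_token_ratio = min(0.7 * 1.0, 1.0)                   # 0.7
min_new_token_ratio = min(init_new_token_ratio * 0.14, 1.0)  # 0.098
decay = (init_new_token_ratio - min_new_token_ratio) / 600
print(decay)  # ~0.001003 per step, close to the old hardcoded 0.001
```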
@@ -307,7 +316,7 @@
                 self.process_batch_result(batch, result)
             else:
                 self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+                self.new_token_ratio = self.init_new_token_ratio

             self.last_batch = batch
@@ -334,7 +343,7 @@
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+                self.new_token_ratio = self.init_new_token_ratio

             self.last_batch = batch
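Both hunks change the idle-path reset to use the scheduler's own init_new_token_ratio, which folds in schedule_conservativeness, rather than the raw global constant. A toy model (not SGLang's actual loop), assuming the ratio steps down by new_token_ratio_decay on busy iterations and snaps back when no batch is scheduled:

```python
# Toy model of the decay-and-reset dynamic suggested by these hunks.
init, floor, steps = 0.7, 0.098, 600
decay = (init - floor) / steps

ratio = init
for _ in range(100):                # 100 busy scheduler iterations
    ratio = max(ratio - decay, floor)
print(round(ratio, 4))              # 0.5997

ratio = init                        # idle: reset, as in the hunks above
```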
@@ -121,13 +121,13 @@ class CudaGraphRunner:
             bs
             for bs in self.capture_bs
             if bs <= model_runner.req_to_token_pool.size
-            and bs <= model_runner.server_args.max_cuda_graph_bs
+            and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
         self.compile_bs = (
             [
                 bs
                 for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.max_torch_compile_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
             ]
             if self.use_torch_compile
             else []
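The batch-size filters themselves are unchanged; only the renamed server arguments are referenced. A self-contained sketch of the same filtering logic, with made-up values:

```python
# Sketch of the capture/compile batch-size filtering, with made-up values.
candidate_bs = [1, 2, 4, 8, 16, 32, 64, 128, 160, 192]
req_to_token_pool_size = 256  # assumed pool size
cuda_graph_max_bs = 160       # renamed from max_cuda_graph_bs
torch_compile_max_bs = 32     # renamed from max_torch_compile_bs

capture_bs = [
    bs
    for bs in candidate_bs
    if bs <= req_to_token_pool_size and bs <= cuda_graph_max_bs
]
compile_bs = [bs for bs in capture_bs if bs <= torch_compile_max_bs]

print(capture_bs)  # [1, 2, 4, 8, 16, 32, 64, 128, 160]
print(compile_bs)  # [1, 2, 4, 8, 16, 32]
```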
@@ -119,8 +119,8 @@ class ServerArgs:
     enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
-    max_torch_compile_bs: int = 32
-    max_cuda_graph_bs: int = 160
+    torch_compile_max_bs: int = 32
+    cuda_graph_max_bs: int = 160
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
@@ -620,15 +620,15 @@ class ServerArgs:
             help="Optimize the model with torch.compile. Experimental feature.",
         )
         parser.add_argument(
-            "--max-torch-compile-bs",
+            "--torch-compile-max-bs",
             type=int,
-            default=ServerArgs.max_torch_compile_bs,
+            default=ServerArgs.torch_compile_max_bs,
             help="Set the maximum batch size when using torch compile.",
         )
         parser.add_argument(
-            "--max-cuda-graph-bs",
+            "--cuda-graph-max-bs",
             type=int,
-            default=ServerArgs.max_cuda_graph_bs,
+            default=ServerArgs.cuda_graph_max_bs,
             help="Set the maximum batch size for cuda graph.",
         )
         parser.add_argument(
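The renamed CLI flags follow the `<feature>-max-bs` pattern. A hypothetical launch using them (the model path is illustrative, not from this commit):

```python
import subprocess

# Hypothetical launch with the renamed flags; the old spellings were
# --max-torch-compile-bs and --max-cuda-graph-bs.
subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    "--enable-torch-compile",
    "--torch-compile-max-bs", "16",
    "--cuda-graph-max-bs", "128",
])
```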