Unverified commit 4cd64b8e, authored by Liangsheng Yin and committed by GitHub

Auto adjust new ratio (#708)

parent 01d66ae2
@@ -16,9 +16,9 @@ class GlobalConfig:
         self.wait_for_new_request_delay = 0.0006

         # Runtime constants: New generation token ratio estimation
-        self.base_new_token_ratio = 0.4
+        self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.2
-        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_decay = 0.001
         self.new_token_ratio_recovery = 0.05

         # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
@@ -27,6 +27,7 @@ class GlobalConfig:
         # Runtime constants: others
         self.num_continue_decode_steps = 10
+        self.retract_decode_steps = 20
         self.flashinfer_workspace_size = 192 * 1024 * 1024

         # Output tokenization configs
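For context, not part of the diff: these constants drive the scheduler's estimate of how many more tokens each running request will still generate. A minimal sketch of the schedule they imply, assuming the ratio is reduced by new_token_ratio_decay once per decode step and floored at the minimum; the decayed_ratio helper is illustrative, not sglang API:

# Illustrative only -- mirrors the constants above, not sglang's scheduler.
init_new_token_ratio = 0.7      # optimistic starting estimate
base_min_new_token_ratio = 0.2  # floor for the estimate
new_token_ratio_decay = 0.001   # assumed: subtracted once per decode step

def decayed_ratio(step: int) -> float:
    """Estimated new-token ratio after `step` decode steps."""
    return max(init_new_token_ratio - step * new_token_ratio_decay,
               base_min_new_token_ratio)

print(decayed_ratio(0))    # 0.7
print(decayed_ratio(500))  # 0.2 -- the floor is reached after 500 steps

Under these assumptions, raising the decay from 0.0001 to 0.001 lets the estimate relax from 0.7 to the 0.2 floor in 500 steps rather than 5000, which pairs with resetting to init_new_token_ratio when the queue drains (see the ModelTpServer hunks below).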
@@ -9,6 +9,7 @@ import numpy as np
 import torch
 from flashinfer.sampling import top_k_top_p_sampling_from_probs

+from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
@@ -431,7 +432,8 @@ class Batch:
     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
-        # TODO(lsyin): improve the priority of retraction
+
+        # TODO(lsyin): improve retraction policy for radix cache
         sorted_indices.sort(
             key=lambda i: (
                 len(self.reqs[i].output_ids),
@@ -443,7 +445,17 @@ class Batch:
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        while self.token_to_kv_pool.available_size() < len(self.reqs):
+        while (
+            self.token_to_kv_pool.available_size()
+            < len(sorted_indices) * global_config.retract_decode_steps
+        ):
+            if len(sorted_indices) == 1:
+                # Corner case: only one request left
+                assert (
+                    self.token_to_kv_pool.available_size() > 0
+                ), "No space left for only one request"
+                break
+
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
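For context, not part of the diff: the old loop retracted requests only until there was one free KV-cache slot per request, so the batch could hit out-of-memory again on the very next decode step. The new condition keeps retracting until every surviving request has retract_decode_steps (20) free token slots of head room. A toy worked example; the pool size and per-request KV usage are made up:

# Toy numbers; available_size() and per-request KV usage are assumed.
retract_decode_steps = 20   # from global_config
available = 150             # free slots in token_to_kv_pool
remaining = 10              # len(sorted_indices)

while available < remaining * retract_decode_steps:  # 150 < 200 -> retract
    remaining -= 1          # drop the lowest-priority request ...
    available += 30         # ... and reclaim its KV cache (assumed size)

print(remaining, available)  # 9 180: every survivor now has 20 steps of room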
@@ -468,7 +480,16 @@ class Batch:
         self.filter_batch(sorted_indices)

-        return retracted_reqs
+        # Reqs in batch are filtered
+        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
+        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
+
+        new_estimate_ratio = (
+            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+        ) / total_max_new_tokens
+        new_estimate_ratio = min(1.0, new_estimate_ratio)
+
+        return retracted_reqs, new_estimate_ratio

     def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
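For context, not part of the diff: rather than nudging the ratio up by a fixed recovery constant, retract_decode now measures the surviving batch directly: tokens already decoded, plus the ~20-step head room just reserved, divided by the batch's total max_new_tokens budget. Worked through with made-up numbers:

# Made-up batch: four surviving requests, each mid-generation.
retract_decode_steps = 20
decoded = [100, 100, 100, 100]   # len(r.output_ids) per request
max_new = [256, 256, 256, 256]   # r.sampling_params.max_new_tokens

new_estimate_ratio = (
    sum(decoded) + retract_decode_steps * len(decoded)
) / sum(max_new)
new_estimate_ratio = min(1.0, new_estimate_ratio)
print(new_estimate_ratio)        # (400 + 80) / 1024 = 0.46875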
@@ -228,6 +228,7 @@ class ModelTpServer:
                 break
         else:
             self.check_memory()
+            self.new_token_ratio = global_config.init_new_token_ratio

     def print_stats(self):
         num_used = self.max_total_num_tokens - (
@@ -536,9 +537,10 @@ class ModelTpServer:
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
-            retracted_reqs = batch.retract_decode()
+
+            retracted_reqs, new_token_ratio = batch.retract_decode()
+            self.new_token_ratio = new_token_ratio

             logger.info(
                 "decode out of memory happened, "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
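For context, not part of the diff: taken together, the ModelTpServer changes replace the old fixed bump of new_token_ratio_recovery on every out-of-memory event with a closed feedback loop. A condensed sketch of that loop; the class and method names here are paraphrased for illustration, and only retract_decode matches the real code:

# Paraphrased control flow, not the real ModelTpServer.
class RatioFeedback:
    def __init__(self, cfg):
        self.cfg = cfg
        self.new_token_ratio = cfg.init_new_token_ratio

    def on_decode_oom(self, batch):
        # New behavior: take the measured estimate from retract_decode
        # instead of bumping the old ratio by new_token_ratio_recovery.
        retracted_reqs, new_token_ratio = batch.retract_decode()
        self.new_token_ratio = new_token_ratio
        return retracted_reqs

    def on_queue_drained(self):
        # Mirrors the @@ -228 hunk: once all requests finish, restart
        # the next burst from the optimistic initial estimate.
        self.new_token_ratio = self.cfg.init_new_token_ratio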