Unverified commit 80407b04 authored by YAMY, committed by GitHub

Fix: Dynamic RoPE Cache Expansion to Prevent Position-ID Out-of-Bounds in EAGLE + Long-Sequence Workloads (#10788)
parent b288f4f4
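
A minimal sketch of the failure mode this commit addresses (illustrative shapes, not taken from the diff): when EAGLE draft expansion or a long prompt produces position ids past the end of the precomputed cos/sin table, indexing the cache goes out of bounds (an `IndexError` on CPU, a device-side assert on CUDA).

```python
import torch

cos_sin_cache = torch.randn(2048, 128)        # covers positions 0..2047
positions = torch.tensor([2046, 2047, 2048])  # 2048 is one past the cache

try:
    cos_sin_cache[positions]
except IndexError as e:
    print(e)  # index 2048 is out of bounds for dimension 0 with size 2048
```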
@@ -222,6 +222,11 @@ class Envs:
     SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
     SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)

+    # RoPE cache configuration
+    SGLANG_SPEC_EXPANSION_SAFETY_FACTOR = EnvInt(2)
+    SGLANG_ROPE_CACHE_SAFETY_MARGIN = EnvInt(256)
+    SGLANG_ROPE_CACHE_ALIGN = EnvInt(128)
+
     # Overlap Spec V2
     SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
...
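
The three new knobs are read back through `envs.*.value` in `reserve_rope_cache_for_long_sequences` below. Assuming `EnvInt` resolves against the process environment like the other `SGLANG_*` knobs (this is an inference from the `.value` accesses in this diff, not something the diff states), the defaults can be overridden before launch:

```python
import os

# Hypothetical override, shown for illustration only: loosen the alignment
# and safety margin before the process imports sglang.
os.environ["SGLANG_ROPE_CACHE_ALIGN"] = "256"
os.environ["SGLANG_ROPE_CACHE_SAFETY_MARGIN"] = "512"

from sglang.srt.environ import envs

print(envs.SGLANG_ROPE_CACHE_ALIGN.value)  # expected: 256
```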
@@ -147,6 +147,36 @@ class RotaryEmbedding(CustomOp):
         cache = torch.cat((cos, sin), dim=-1)
         return cache

+    def _ensure_cos_sin_cache_length(self, needed_max_pos: int):
+        """Ensure cos_sin_cache length > needed_max_pos."""
+        cur_len = int(self.cos_sin_cache.shape[0])
+        if needed_max_pos < cur_len:
+            return
+        # Align to 128 to reduce realloc frequency
+        new_len = ((needed_max_pos + 128) // 128) * 128
+        device = self.cos_sin_cache.device
+        dtype = self.cos_sin_cache.dtype
+        # Compute inv_freq on the same device as the cache
+        inv_freq = self._compute_inv_freq(self.base).to(device=device)
+        # Incremental computation for the new positions only
+        start = cur_len
+        t_new = torch.arange(start, new_len, dtype=inv_freq.dtype, device=device)
+        if t_new.numel() == 0:
+            return
+        freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
+        cos_new = freqs_new.cos()
+        sin_new = freqs_new.sin()
+        new_rows = torch.cat((cos_new, sin_new), dim=-1).to(dtype=dtype)
+        # Update the cache with the new rows
+        self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0).to(
+            device=device, dtype=dtype
+        )
+
     def forward_native(
         self,
         positions: torch.Tensor,
...
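
A standalone sketch (not part of the diff) of the invariant this method relies on: rows computed incrementally for positions `cur_len..new_len-1` and concatenated onto the old cache match a from-scratch recompute, because each row depends only on its own position `t` and on `inv_freq`. The `base` and `rotary_dim` values below are illustrative.

```python
import torch

base, rotary_dim = 10000.0, 64
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))

def build_cache(max_pos: int) -> torch.Tensor:
    # Full recompute: one (cos | sin) row per position 0..max_pos-1
    t = torch.arange(max_pos, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

# Incremental rows for positions 2048..4095 only, as the new method computes them
t_new = torch.arange(2048, 4096, dtype=inv_freq.dtype)
freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
new_rows = torch.cat((freqs_new.cos(), freqs_new.sin()), dim=-1)

extended = torch.cat((build_cache(2048), new_rows), dim=0)
assert torch.allclose(extended, build_cache(4096))  # same result, less work
```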
@@ -140,6 +140,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
+    reserve_rope_cache_for_long_sequences,
     set_cuda_arch,
     slow_rank_detector,
 )
@@ -898,6 +899,15 @@ class ModelRunner:
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )

+        # Pre-expand RoPE cache before CUDA Graph capture
+        reserve_rope_cache_for_long_sequences(
+            self.model,
+            self.server_args,
+            self.model_config,
+            self.req_to_token_pool,
+            logger,
+        )
+
         if self.server_args.elastic_ep_backend == "mooncake":
             # Mooncake does not support `monitored_barrier`
             dist.barrier(group=get_tp_group().cpu_group)
...
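
Why this must run before CUDA Graph capture (a reviewer-style note with an illustrative snippet, not from the diff): `_ensure_cos_sin_cache_length` replaces `cos_sin_cache` via `torch.cat`, which allocates new storage. A graph captured earlier would have baked in the old buffer's address and would keep reading stale memory, so all expansion has to be finished first.

```python
import torch

cache = torch.randn(2048, 128)
old_ptr = cache.data_ptr()

# torch.cat allocates a fresh output tensor while the old one is still alive,
# so the cache necessarily moves to a new address.
cache = torch.cat((cache, torch.randn(128, 128)), dim=0)
assert cache.data_ptr() != old_ptr  # reallocation: new storage, new address
```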
@@ -3460,3 +3460,61 @@ def cached_triton_kernel(key_fn=None):
         return CachedKernel(fn, key_fn)

     return decorator
+
+
+def reserve_rope_cache_for_long_sequences(
+    model, server_args, model_config, req_to_token_pool=None, logger=None
+):
+    """Pre-expand RoPE caches for long sequences and speculative decoding."""
+    from sglang.srt.environ import envs
+
+    if logger is None:
+        import logging
+
+        logger = logging.getLogger(__name__)
+
+    SAFETY_FACTOR = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value
+    MARGIN = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value
+    ALIGN = envs.SGLANG_ROPE_CACHE_ALIGN.value
+
+    # 1) Estimate the base context upper bound
+    base_ctx = (
+        getattr(server_args, "context_length", None)
+        or getattr(model_config, "context_len", None)
+        or getattr(model_config, "max_model_len", None)
+        or getattr(model_config.hf_text_config, "max_position_embeddings", None)
+        or 2048
+    )
+
+    # 2) Runtime input capacity (including extra_len from req_to_token_pool)
+    inferred_cap = getattr(req_to_token_pool, "max_context_len", None) or base_ctx
+
+    # 3) Speculative decoding expansion
+    steps = int(getattr(server_args, "speculative_num_steps", 0) or 0)
+    draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0)
+    reserve = inferred_cap + steps * draft * SAFETY_FACTOR + MARGIN
+
+    # 4) Align to reduce reallocation frequency
+    reserve = (reserve + ALIGN - 1) // ALIGN * ALIGN
+
+    logger.info(
+        f"RoPE cache reserve={reserve} (base={base_ctx}, cap={inferred_cap}, "
+        f"steps={steps}, draft={draft}, k={SAFETY_FACTOR}, margin={MARGIN})"
+    )
+
+    # Recursively expand all RoPE layers
+    def reserve_rope_cache_recursive(module):
+        for child in module.children():
+            if hasattr(child, "_ensure_cos_sin_cache_length") and hasattr(
+                child, "cos_sin_cache"
+            ):
+                old_len = child.cos_sin_cache.shape[0]
+                child._ensure_cos_sin_cache_length(reserve - 1)
+                new_len = child.cos_sin_cache.shape[0]
+                if new_len > old_len:
+                    logger.info(
+                        f"Expanded RoPE cache from {old_len} to {new_len} positions"
+                    )
+            else:
+                reserve_rope_cache_recursive(child)
+
+    reserve_rope_cache_recursive(model)
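
Worked example of the reservation arithmetic with illustrative numbers (a 32K-context model running EAGLE with `speculative_num_steps=5`, `speculative_num_draft_tokens=4`, and the default knobs):

```python
inferred_cap, steps, draft = 32768, 5, 4
SAFETY_FACTOR, MARGIN, ALIGN = 2, 256, 128  # defaults from environ.py above

reserve = inferred_cap + steps * draft * SAFETY_FACTOR + MARGIN
# 32768 + 5*4*2 + 256 = 33064
reserve = (reserve + ALIGN - 1) // ALIGN * ALIGN
# rounded up to the next multiple of 128 -> 33152
print(reserve)  # every RoPE cache is then grown to cover at least 33152 positions
```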