Unverified commit eb6c2c16, authored by tarinkk, committed by GitHub

Hybrid kv cache for LLaMA4 (#6563)


Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: tarinkk <rt572@physics.rutger.edu>
Co-authored-by: tarinkk <rt572@rutgers.physics.edu>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
parent 357921aa
@@ -16,10 +16,11 @@ python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-In
 ### Configuration Tips
-- **OOM Mitigation**: Adjust `--context-length` to avoid a GPU out-of-memory issue. For the Scout model, we recommend setting this value up to 1M on 8\*H100 and up to 2.5M on 8\*H200. For the Maverick model, we don't need to set context length on 8\*H200.
+- **OOM Mitigation**: Adjust `--context-length` to avoid a GPU out-of-memory issue. For the Scout model, we recommend setting this value up to 1M on 8\*H100 and up to 2.5M on 8\*H200. For the Maverick model, no context length needs to be set on 8\*H200. When the hybrid KV cache is enabled, `--context-length` can be set up to 5M on 8\*H100 and up to 10M on 8\*H200 for the Scout model.
 - **Chat Template**: Add `--chat-template llama-4` for chat completion tasks.
 - **Enable Multi-Modal**: Add `--enable-multimodal` for multi-modal capabilities.
+- **Enable Hybrid KV Cache**: Add `--hybrid-kvcache-ratio` to enable the hybrid KV cache. Details can be found in [this PR](https://github.com/sgl-project/sglang/pull/6563).

 ## Benchmarking Results
......
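For reference, a launch command exercising the new flag could look like `python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --tp 8 --context-length 5000000 --hybrid-kvcache-ratio 0.5 --chat-template llama-4` (illustrative only; the `--tp` and `--context-length` values are assumptions and should be adapted to the deployment).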
@@ -59,6 +59,7 @@ class ModelConfig:
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
+        hybrid_kvcache_ratio: Optional[float] = None,
         impl: Union[str, ModelImpl] = ModelImpl.AUTO,
     ) -> None:
@@ -86,6 +87,18 @@ class ModelConfig:
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
         )
+        self.is_hybrid = is_hybrid_model(
+            self.hf_config.architectures,
+            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
+            context_length=context_length,
+            attention_chunk_size=self.attention_chunk_size,
+        )
+        if self.is_hybrid is not None:
+            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
+                get_hybrid_layer_ids(
+                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
+                )
+            )

         if enable_multimodal is None:
             mm_disabled_models = [
@@ -264,6 +277,7 @@ class ModelConfig:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
             impl=server_args.impl,
             **kwargs,
         )
@@ -633,3 +647,36 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
     return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def is_hybrid_model(
+    model_architectures: List[str],
+    hybrid_kvcache_ratio: Optional[float],
+    context_length: Optional[int],
+    attention_chunk_size: Optional[int],
+):
+    if hybrid_kvcache_ratio is None:
+        return None
+    elif (
+        hybrid_kvcache_ratio > 0
+        and model_architectures[0] == "Llama4ForConditionalGeneration"
+        and context_length > attention_chunk_size
+    ):
+        return hybrid_kvcache_ratio
+    else:
+        return None
+
+
+def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
+    if "Llama4ForConditionalGeneration" in model_architectures:
+        swa_attention_layer_ids = [
+            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
+        ]
+        full_attention_layer_ids = [
+            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
+        ]
+    else:
+        raise ValueError(
+            "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
+        )
+    return swa_attention_layer_ids, full_attention_layer_ids
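As a sanity check on `get_hybrid_layer_ids`: every fourth layer keeps a full-attention KV cache, while the remaining layers use the sliding-window (chunked attention) cache. The sketch below reproduces the partition for a hypothetical 8-layer model (illustrative, not a real Llama 4 configuration).

```python
# Minimal sketch of the layer partition produced by get_hybrid_layer_ids.
# num_hidden_layers = 8 is an illustrative assumption, not a real config.
num_hidden_layers = 8
swa_attention_layer_ids = [i for i in range(num_hidden_layers) if (i + 1) % 4 != 0]
full_attention_layer_ids = [i for i in range(num_hidden_layers) if (i + 1) % 4 == 0]
print(swa_attention_layer_ids)   # [0, 1, 2, 4, 5, 6]
print(full_attention_layer_ids)  # [3, 7]
```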
@@ -433,9 +433,7 @@ class DecodePreallocQueue:
             else 0
         )

-        available_size = self.token_to_kv_pool_allocator.available_size()
-        allocatable_tokens = available_size - max(
+        allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
             # preserve some space for future decode
             self.num_reserved_decode_tokens
             * (
......
@@ -9,6 +9,7 @@ import torch
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -320,6 +321,11 @@ class FlashAttentionBackend(AttentionBackend):
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
+        self.is_hybrid = model_runner.is_hybrid
+        if self.is_hybrid:
+            self.full_to_swa_index_mapping = (
+                model_runner.token_to_kv_pool.full_to_swa_index_mapping
+            )
         self.topk = model_runner.server_args.speculative_eagle_topk or 0
         self.speculative_num_steps = speculative_num_steps
         self.speculative_num_draft_tokens = (
@@ -428,7 +434,7 @@ class FlashAttentionBackend(AttentionBackend):
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
             # TODO: we need to test this part for llama 4 eagle case
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = (
@@ -456,7 +462,7 @@ class FlashAttentionBackend(AttentionBackend):
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                self._init_local_attn_metadata(metadata, device)
+                self._init_local_attn_metadata(forward_batch, metadata, device)
         else:
             metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
             metadata.max_seq_len_q = self.speculative_num_draft_tokens
@@ -575,7 +581,7 @@ class FlashAttentionBackend(AttentionBackend):
         # Setup local attention if enabled
         if forward_batch.forward_mode == ForwardMode.EXTEND:
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)

         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
@@ -1588,7 +1594,7 @@ class FlashAttentionBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
-        out_cache_loc: torch.Tensor = None,
+        out_cache_loc: Optional[torch.Tensor] = None,
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         seq_lens = seq_lens[:bs]
@@ -1673,7 +1679,10 @@ class FlashAttentionBackend(AttentionBackend):
                 self.page_size,
             )
-            self._update_local_attn_metadata_for_replay(metadata, bs)
+            self._update_local_attn_metadata_for_replay(
+                metadata,
+                bs,
+            )
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
@@ -1829,7 +1838,9 @@ class FlashAttentionBackend(AttentionBackend):
         """Get the fill value for sequence length in CUDA graph."""
         return 1

-    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
+    def _init_local_attn_metadata(
+        self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device
+    ):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
         if self.attention_chunk_size is None:
             metadata.local_attn_metadata = None
@@ -1837,7 +1848,12 @@ class FlashAttentionBackend(AttentionBackend):
         cu_seqlens_q = metadata.cu_seqlens_q
         cache_seqlens_int32 = metadata.cache_seqlens_int32
-        page_table = metadata.page_table
+        if self.is_hybrid:
+            page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
+                torch.int32
+            )
+        else:
+            page_table = metadata.page_table
         if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
             metadata.local_attn_metadata = None
             return
@@ -1923,7 +1939,9 @@ class FlashAttentionBackend(AttentionBackend):
         )

     def _update_local_attn_metadata_for_replay(
-        self, metadata: FlashAttentionMetadata, bs: int
+        self,
+        metadata: FlashAttentionMetadata,
+        bs: int,
     ):
         """Update preallocated local attention metadata in-place before CUDA graph replay."""
         if self.attention_chunk_size is None:
@@ -1954,7 +1972,12 @@ class FlashAttentionBackend(AttentionBackend):
         # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
         # beyond the actual sequence length, leading to incorrect attention calculations
         max_seq_len = int(seqlens.max().item())
-        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+        if self.is_hybrid:
+            sliced_page_table = self.full_to_swa_index_mapping[
+                metadata.page_table[:bs, :max_seq_len]
+            ].to(torch.int32)
+        else:
+            sliced_page_table = metadata.page_table[:bs, :max_seq_len]

         cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
         seqlens_np = seqlens.cpu().numpy()
......
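The `full_to_swa_index_mapping` lookups added above translate page-table entries from full-attention KV slots to their paired sliding-window slots before local attention runs. A minimal sketch of the lookup, with made-up slot numbers (not taken from the PR):

```python
import torch

# Hypothetical mapping: full-pool slot -> SWA-pool slot (0 means "no SWA slot").
full_to_swa_index_mapping = torch.zeros(16, dtype=torch.int64)
full_to_swa_index_mapping[torch.tensor([3, 4, 5])] = torch.tensor([7, 8, 9])

page_table = torch.tensor([[3, 4, 5, 0]])  # full-pool indices for one request
swa_page_table = full_to_swa_index_mapping[page_table].to(torch.int32)
print(swa_page_table)  # tensor([[7, 8, 9, 0]], dtype=torch.int32)
```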
@@ -56,7 +56,7 @@ from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -485,6 +485,9 @@ class Req:
         # for cross-encoder model
         self.token_type_ids = token_type_ids

+        # The length of KV cache that has been removed in local attention chunked prefill
+        self.evicted_seqlen_local = 0
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
@@ -1191,6 +1194,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.req_to_token_pool.write(
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
+            if isinstance(self.tree_cache, SWAChunkCache):
+                self.tree_cache.evict(
+                    req, pre_len, self.model_config.attention_chunk_size
+                )

             # If input_embeds are available, store them
             if req.input_embeds is not None:
@@ -1383,7 +1390,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             * buf_multiplier
             * self.token_to_kv_pool_allocator.page_size
         )
-
         if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
             return True
@@ -1564,6 +1570,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens.add_(1)
         self.seq_lens_sum += bs

+        # free memory
+        if isinstance(self.tree_cache, SWAChunkCache):
+            for req in self.reqs:
+                self.tree_cache.evict(
+                    req, req.seqlen - 1, self.model_config.attention_chunk_size
+                )
+
         # Allocate memory
         if self.token_to_kv_pool_allocator.page_size == 1:
             self.out_cache_loc = self.alloc_token_slots(bs)
@@ -1798,7 +1811,6 @@ class ModelWorkerBatch:
     seq_lens: torch.Tensor
     # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
-
     # The sequence length tensor on CPU
     seq_lens_cpu: Optional[torch.Tensor]
     seq_lens_sum: int
......
@@ -126,7 +126,8 @@ from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
+from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -570,7 +571,11 @@ class Scheduler(
             server_args.chunked_prefill_size is not None
             and server_args.disable_radix_cache
         ):
-            self.tree_cache = ChunkCache(
+            if self.model_config.is_hybrid:
+                ChunkCacheClass = SWAChunkCache
+            else:
+                ChunkCacheClass = ChunkCache
+            self.tree_cache = ChunkCacheClass(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 page_size=self.page_size,
@@ -1283,9 +1288,8 @@ class Scheduler(
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens

-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
+        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
+            self.tree_cache.evictable_size()
         )

         num_new_seq = len(can_run_list)
@@ -1294,7 +1298,7 @@ class Scheduler(
             f"#new-seq: {num_new_seq}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"{usage_msg}"
         )

         if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1337,9 +1341,8 @@ class Scheduler(
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
+        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
+            self.tree_cache.evictable_size()
         )

         if RECORD_STEP_TIME:
@@ -1347,12 +1350,7 @@ class Scheduler(
                 gap_latency / self.server_args.decode_log_interval
             )

-        msg = (
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-        )
+        msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"

         if self.spec_algorithm.is_none():
             spec_accept_length = 0
@@ -1390,10 +1388,11 @@ class Scheduler(
         self._publish_kv_events()

     def check_memory(self):
-        available_size = (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
-        )
+        if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
+            available_token_size = self.token_to_kv_pool_allocator.full_available_size()
+        else:
+            available_token_size = self.token_to_kv_pool_allocator.available_size()
+        available_size = available_token_size + self.tree_cache.evictable_size()
         protected_size = self.tree_cache.protected_size()
         memory_leak = available_size != (
             self.max_total_num_tokens
@@ -1404,7 +1403,7 @@ class Scheduler(
             msg = (
                 "token_to_kv_pool_allocator memory leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
-                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{available_token_size=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
             raise ValueError(msg)
......
@@ -20,12 +20,14 @@ Page-aligned memory pool.
 """

 import abc
+import weakref
 from typing import TYPE_CHECKING

 import torch
 import triton
 import triton.language as tl

+from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.utils import get_bool_env_var, next_power_of_2

 if TYPE_CHECKING:
@@ -55,6 +57,11 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
     def debug_print(self) -> str:
         return ""

+    def log_usage(self, evictable_size: int = 0):
+        num_used = self.size - (self.available_size() + evictable_size)
+        msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
+        return msg, num_used
+
     def available_size(self):
         return len(self.free_pages) * self.page_size
@@ -146,6 +153,128 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)


+class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
+    """Allocator for SWA hybrid KV cache."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: SWAKVPool,
+    ):
+        super().__init__(size, 1, dtype, device, kvcache)
+        assert isinstance(kvcache, SWAKVPool)
+        self._size_full = size
+        self._size_swa = size_swa
+        self.full_attn_allocator = TokenToKVPoolAllocator(
+            size,
+            dtype,
+            device,
+            kvcache.full_kv_pool,
+        )
+        self.swa_attn_allocator = TokenToKVPoolAllocator(
+            size_swa,
+            dtype,
+            device,
+            kvcache.swa_kv_pool,
+        )
+        self.full_to_swa_index_mapping = torch.empty(
+            size + size_swa + 1,
+            dtype=torch.int64,
+            device=device,
+        )
+        self.clear()
+
+        self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping
+
+    def available_size(self):
+        return min(self.full_available_size(), self.swa_available_size())
+
+    def full_available_size(self):
+        return self.full_attn_allocator.available_size()
+
+    def swa_available_size(self):
+        return self.swa_attn_allocator.available_size()
+
+    @property
+    def size_full(self):
+        return self._size_full
+
+    @property
+    def size_swa(self):
+        return self._size_swa
+
+    def debug_print(self) -> str:
+        msg = ""
+        msg += f"#swa-available-size: {self.swa_attn_allocator.available_size()}, "
+        msg += (
+            f"#full-attn-available-size: {self.full_attn_allocator.available_size()}, "
+        )
+        return msg
+
+    def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
+        used_full = self.size_full - (self.full_available_size() + full_evictable_size)
+        used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
+        msg = (
+            f"#token: full={used_full}, swa={used_swa}, "
+            f"token usage: full={used_full / self.size_full:.2f}, "
+            f"swa={used_swa / self.size_swa:.2f}, "
+        )
+        return msg, used_full
+
+    def get_kvcache(self):
+        return self._kvcache
+
+    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
+        assert self.full_to_swa_index_mapping is not None
+        return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
+
+    def alloc(self, need_size: int):
+        if need_size > self.full_attn_allocator.available_size():
+            return None
+        if need_size > self.swa_attn_allocator.available_size():
+            return None
+
+        alloc_full_indices = self.full_attn_allocator.alloc(need_size)
+        alloc_swa_indices = self.swa_attn_allocator.alloc(need_size)
+        self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices
+        return alloc_full_indices
+
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+        if self.is_not_in_free_group:
+            self.full_attn_allocator.free(free_index)
+            self.free_swa(free_index)
+        else:
+            self.free_group.append(free_index)
+        assert (
+            self.full_attn_allocator.available_size() <= self.full_attn_allocator.size
+        )
+        assert self.swa_attn_allocator.available_size() <= self.swa_attn_allocator.size
+
+    def free_swa(self, free_index: torch.Tensor):
+        swa_indices = self.full_to_swa_index_mapping[free_index]
+        swa_indices = swa_indices[swa_indices > 0]
+        self.swa_attn_allocator.free(swa_indices)
+        self.full_to_swa_index_mapping[free_index] = 0
+
+    def backup_state(self):
+        raise NotImplementedError
+
+    def restore_state(self, state):
+        raise NotImplementedError
+
+    def clear(self):
+        self.swa_attn_allocator.clear()
+        self.full_attn_allocator.clear()
+        self.full_to_swa_index_mapping.fill_(0)
+        self.is_in_free_group = False
+        self.free_group = []
+
+
 @triton.jit
 def alloc_extend_kernel(
     pre_lens_ptr,
......
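In `SWATokenToKVPoolAllocator.alloc`, every full-attention slot is paired with a sliding-window slot and the pair is recorded in `full_to_swa_index_mapping`; `free_swa` can later release only the SWA half (for tokens that fall out of the attention window) while the full-attention half stays allocated. A minimal, self-contained sketch of that bookkeeping in plain torch (not the sglang classes, sizes are illustrative):

```python
import torch

size_full, size_swa = 8, 4
full_to_swa_index_mapping = torch.zeros(size_full + size_swa + 1, dtype=torch.int64)

# "alloc(2)": pair full slots 1, 2 with SWA slots 1, 2.
full_idx = torch.tensor([1, 2])
full_to_swa_index_mapping[full_idx] = torch.tensor([1, 2])

# "free_swa(full_idx[:1])": release only the SWA half of full slot 1.
freed = full_to_swa_index_mapping[full_idx[:1]]
freed = freed[freed > 0]              # ignore slots that have no SWA pair
full_to_swa_index_mapping[full_idx[:1]] = 0
print(freed.tolist(), full_to_swa_index_mapping[:4].tolist())  # [1] [0, 0, 2, 0]
```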
@@ -2,11 +2,14 @@ from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple

 import torch

-from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    SWATokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
@@ -63,3 +66,32 @@ class ChunkCache(BasePrefixCache):
     def pretty_print(self):
         return ""
+
+
+class SWAChunkCache(ChunkCache):
+    """ChunkCache with support for hybrid KV cache operations."""
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: SWATokenToKVPoolAllocator,
+        page_size: int,
+    ):
+        super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
+        assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
+
+    def evict(
+        self,
+        req: Req,
+        prelen: int,
+        attention_chunk_size: int,
+    ):
+        if prelen >= req.evicted_seqlen_local + attention_chunk_size:
+            new_evicted_seqlen_local = attention_chunk_size * (
+                prelen // attention_chunk_size
+            )
+            free_slots = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local
+            ]
+            self.token_to_kv_pool_allocator.free_swa(free_slots)
+            req.evicted_seqlen_local = new_evicted_seqlen_local
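`SWAChunkCache.evict` frees SWA slots only in whole `attention_chunk_size` steps: once the processed prefix reaches the next chunk boundary past `req.evicted_seqlen_local`, everything up to that boundary is released from the sliding-window pool. A small worked example of the boundary arithmetic (values chosen for illustration):

```python
# Worked example of the chunk-boundary arithmetic in SWAChunkCache.evict.
attention_chunk_size = 8192   # Llama 4 uses chunked attention of this size
evicted_seqlen_local = 0      # nothing evicted yet for this request
for prelen in (5000, 9000, 20000):
    if prelen >= evicted_seqlen_local + attention_chunk_size:
        evicted_seqlen_local = attention_chunk_size * (prelen // attention_chunk_size)
    print(prelen, evicted_seqlen_local)
# 5000 0      -> below the first boundary, nothing freed
# 9000 8192   -> SWA slots for tokens [0, 8192) can be dropped
# 20000 16384 -> SWA slots for tokens [0, 16384) can be dropped
```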
@@ -27,10 +27,11 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
+import torch.distributed as dist
 import triton
 import triton.language as tl
@@ -66,6 +67,7 @@ class ReqToTokenPool:
             self.req_to_token = torch.zeros(
                 (size, max_context_len), dtype=torch.int32, device=device
             )
+
         self.free_slots = list(range(size))

     def write(self, indices, values):
@@ -191,7 +193,6 @@ class MHATokenToKVPool(KVCache):
             start_layer,
             end_layer,
         )
-
         self.head_num = head_num
         self.head_dim = head_dim
@@ -392,10 +393,14 @@ class MHATokenToKVPool(KVCache):
         cache_v: torch.Tensor,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
     ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

-        layer_id = layer.layer_id
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
                 cache_k.div_(k_scale)
@@ -431,6 +436,136 @@ class MHATokenToKVPool(KVCache):
         )


+class SWAKVPool(KVCache):
+    """KV cache with separate pools for full and SWA attention layers."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        swa_attention_layer_ids: List[int],
+        full_attention_layer_ids: List[int],
+        enable_kvcache_transpose: bool,
+        device: str,
+    ):
+        self.size = size
+        self.size_swa = size_swa
+        self.dtype = dtype
+        self.device = device
+        self.swa_layer_nums = len(swa_attention_layer_ids)
+        self.full_layer_nums = len(full_attention_layer_ids)
+        self.page_size = 1
+        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+        assert not enable_kvcache_transpose
+        TokenToKVPoolClass = MHATokenToKVPool
+        self.swa_kv_pool = TokenToKVPoolClass(
+            size=size_swa,
+            page_size=self.page_size,
+            dtype=dtype,
+            head_num=head_num,
+            head_dim=head_dim,
+            layer_num=self.swa_layer_nums,
+            device=device,
+            enable_memory_saver=False,
+        )
+        self.full_kv_pool = TokenToKVPoolClass(
+            size=size,
+            page_size=self.page_size,
+            dtype=dtype,
+            head_num=head_num,
+            head_dim=head_dim,
+            layer_num=self.full_layer_nums,
+            device=device,
+            enable_memory_saver=False,
+        )
+        self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
+        for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
+            self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
+        for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
+            self.layers_mapping[global_layer_id] = (swa_layer_id, True)
+        self.full_to_swa_index_mapping: Optional[torch.Tensor] = None
+
+    def get_kv_size_bytes(self):
+        raise NotImplementedError
+
+    def get_contiguous_buf_infos(self):
+        full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
+            self.full_kv_pool.get_contiguous_buf_infos()
+        )
+        swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
+            self.swa_kv_pool.get_contiguous_buf_infos()
+        )
+
+        kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
+        kv_data_lens = full_kv_data_lens + swa_kv_data_lens
+        kv_item_lens = full_kv_item_lens + swa_kv_item_lens
+
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
+    def get_key_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_key_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_key_buffer(layer_id_pool)
+
+    def get_value_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_value_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_value_buffer(layer_id_pool)
+
+    def get_kv_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_kv_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_kv_buffer(layer_id_pool)
+
+    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
+        assert self.full_to_swa_index_mapping is not None
+        return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+    ):
+        layer_id = layer.layer_id
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            if self.full_to_swa_index_mapping is not None:
+                loc = self.translate_loc_from_full_to_swa(loc)
+            self.swa_kv_pool.set_kv_buffer(
+                None,
+                loc,
+                cache_k,
+                cache_v,
+                k_scale,
+                v_scale,
+                layer_id_override=layer_id_pool,
+            )
+        else:
+            self.full_kv_pool.set_kv_buffer(
+                None,
+                loc,
+                cache_k,
+                cache_v,
+                k_scale,
+                v_scale,
+                layer_id_override=layer_id_pool,
+            )
+
+
 @triton.jit
 def set_mla_kv_buffer_kernel(
     kv_buffer_ptr,
......
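`SWAKVPool.layers_mapping` translates a global layer id into a (per-pool layer id, is_swa) pair, so each sub-pool only allocates buffers for the layers it actually serves. For the hypothetical 8-layer split used earlier, the mapping would be:

```python
# Sketch of SWAKVPool.layers_mapping for the hypothetical 8-layer split above.
swa_attention_layer_ids = [0, 1, 2, 4, 5, 6]
full_attention_layer_ids = [3, 7]

layers_mapping = {}
for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
    layers_mapping[global_layer_id] = (full_attn_layer_id, False)
for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
    layers_mapping[global_layer_id] = (swa_layer_id, True)
print(layers_mapping)
# {3: (0, False), 7: (1, False), 0: (0, True), 1: (1, True), 2: (2, True),
#  4: (3, True), 5: (4, True), 6: (5, True)}
```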
@@ -74,6 +74,7 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
+    SWATokenToKVPoolAllocator,
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.memory_pool import (
@@ -81,6 +82,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    SWAKVPool,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
@@ -185,6 +187,7 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.is_hybrid = model_config.is_hybrid
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
@@ -437,6 +440,10 @@ class ModelRunner:
             if self.model_config.context_len > 8192:
                 self.mem_fraction_static *= 0.85

+        if self.is_hybrid and not server_args.disable_radix_cache:
+            logger.info("Automatically disable radix cache for hybrid cache.")
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -852,6 +859,40 @@ class ModelRunner:
             max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

+    def set_num_token_hybrid(self):
+        if (
+            "Llama4ForConditionalGeneration"
+            in self.model_config.hf_config.architectures
+        ):
+            temp_ratio = (
+                (1 - self.is_hybrid)
+                + self.is_hybrid
+                * self.attention_chunk_size
+                / self.model_config.context_len
+            )
+            self.swa_max_total_num_tokens = (
+                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
+            )
+            self.full_max_total_num_tokens = (
+                4 * self.max_total_num_tokens
+                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
+            )
+            self.swa_max_total_num_tokens = int(
+                self.swa_max_total_num_tokens
+                // self.server_args.page_size
+                * self.server_args.page_size
+            )
+            self.full_max_total_num_tokens = int(
+                self.full_max_total_num_tokens
+                // self.server_args.page_size
+                * self.server_args.page_size
+            )
+            self.max_total_num_tokens = self.full_max_total_num_tokens
+        else:
+            raise ValueError(
+                f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}."
+            )
+
     def init_memory_pool(
         self,
         total_gpu_memory: int,
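A note on the arithmetic in `set_num_token_hybrid`: `self.is_hybrid` holds the `--hybrid-kvcache-ratio` value, so `temp_ratio` interpolates between 1 (ratio 0, uniform split) and `attention_chunk_size / context_len` (ratio 1, fully hybrid). With 3/4 of the Llama 4 layers using SWA and 1/4 using full attention, the code appears to solve `(3 * swa + full) / 4 = T` with `swa = temp_ratio * full`, where `T` is the uniform per-layer token budget `max_total_num_tokens`; this gives `full = 4T / (3 * temp_ratio + 1)` and `swa = temp_ratio * full` (up to floor/page rounding). A worked example with assumed numbers (not taken from the PR):

```python
# Worked example of the hybrid token-budget split, with illustrative values.
max_total_num_tokens = 1_000_000   # assumed uniform per-layer budget T
hybrid_kvcache_ratio = 1.0         # the value stored in self.is_hybrid
attention_chunk_size = 8192
context_len = 1_000_000            # assumed --context-length

temp_ratio = (1 - hybrid_kvcache_ratio) + hybrid_kvcache_ratio * attention_chunk_size / context_len
swa_tokens = 4 * max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
full_tokens = 4 * max_total_num_tokens - 12 * max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
print(int(swa_tokens), int(full_tokens))  # 31982 3904054
# Check: 3 * 31982 + 3904054 = 4_000_000 = 4 * T, i.e. the same total KV memory
# as the uniform layout; the real code additionally rounds both values down to
# a multiple of page_size.
```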
@@ -929,6 +970,10 @@ class ModelRunner:
                 * self.server_args.page_size
             )

+        # create token size for hybrid cache
+        if self.is_hybrid:
+            self.set_num_token_hybrid()
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
@@ -991,27 +1036,53 @@ class ModelRunner:
                 end_layer=self.end_layer,
             )
         else:
-            self.token_to_kv_pool = MHATokenToKVPool(
-                self.max_total_num_tokens,
-                page_size=self.page_size,
-                dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
-                head_dim=self.model_config.head_dim,
-                layer_num=self.num_effective_layers,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-                start_layer=self.start_layer,
-                end_layer=self.end_layer,
-            )
+            if self.is_hybrid:
+                self.token_to_kv_pool = SWAKVPool(
+                    size=self.full_max_total_num_tokens,
+                    size_swa=self.swa_max_total_num_tokens,
+                    dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
+                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
+                    enable_kvcache_transpose=False,
+                    device=self.device,
+                )
+            else:
+                self.token_to_kv_pool = MHATokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.num_effective_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
+                )

         if self.token_to_kv_pool_allocator is None:
             if self.page_size == 1:
-                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                    self.max_total_num_tokens,
-                    dtype=self.kv_cache_dtype,
-                    device=self.device,
-                    kvcache=self.token_to_kv_pool,
-                )
+                if self.is_hybrid:
+                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
+                        self.full_max_total_num_tokens,
+                        self.swa_max_total_num_tokens,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
+                else:
+                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
             else:
                 self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
......
@@ -61,6 +61,7 @@ class ServerArgs:
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
     revision: Optional[str] = None
+    hybrid_kvcache_ratio: Optional[float] = None
    impl: str = "auto"

     # Port for the HTTP server
@@ -817,6 +818,18 @@ class ServerArgs:
             default=ServerArgs.page_size,
             help="The number of tokens in a page.",
         )
+        parser.add_argument(
+            "--hybrid-kvcache-ratio",
+            nargs="?",
+            const=0.5,
+            type=float,
+            default=ServerArgs.hybrid_kvcache_ratio,
+            help=(
+                "Mix ratio in [0,1] between uniform and hybrid kv buffers "
+                "(0.0 = pure uniform: swa_size / full_size = 1)"
+                "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
+            ),
+        )

         # Other runtime options
         parser.add_argument(
......