Unverified Commit eb6c2c16 authored by tarinkk, committed by GitHub

Hybrid kv cache for LLaMA4 (#6563)


Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: tarinkk <rt572@physics.rutger.edu>
Co-authored-by: tarinkk <rt572@rutgers.physics.edu>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
parent 357921aa
......@@ -16,10 +16,11 @@ python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-In
### Configuration Tips
- **OOM Mitigation**: Adjust `--context-length` to avoid a GPU out-of-memory issue. For the Scout model, we recommend setting this value up to 1M on 8\*H100 and up to 2.5M on 8\*H200. For the Maverick model, we don't need to set context length on 8\*H200.
- **OOM Mitigation**: Adjust `--context-length` to avoid a GPU out-of-memory issue. For the Scout model, we recommend setting this value up to 1M on 8\*H100 and up to 2.5M on 8\*H200. For the Maverick model, we don't need to set context length on 8\*H200. When hybrid kv cache is enabled, `--context-length` can be set up to 5M on 8\*H100 and up to 10M on 8\*H200 for the Scout model.
- **Chat Template**: Add `--chat-template llama-4` for chat completion tasks.
- **Enable Multi-Modal**: Add `--enable-multimodal` for multi-modal capabilities.
- **Enable Hybrid KV Cache**: Add `--hybrid-kvcache-ratio` to enable the hybrid KV cache. See [this PR](https://github.com/sgl-project/sglang/pull/6563) for details.
## Benchmarking Results
......
......@@ -59,6 +59,7 @@ class ModelConfig:
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
......@@ -86,6 +87,18 @@ class ModelConfig:
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
)
self.is_hybrid = is_hybrid_model(
self.hf_config.architectures,
hybrid_kvcache_ratio=hybrid_kvcache_ratio,
context_length=context_length,
attention_chunk_size=self.attention_chunk_size,
)
if self.is_hybrid is not None:
self.swa_attention_layer_ids, self.full_attention_layer_ids = (
get_hybrid_layer_ids(
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
)
)
if enable_multimodal is None:
mm_disabled_models = [
......@@ -264,6 +277,7 @@ class ModelConfig:
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
impl=server_args.impl,
**kwargs,
)
......@@ -633,3 +647,36 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def is_hybrid_model(
model_architectures: List[str],
hybrid_kvcache_ratio: Optional[float],
context_length: Optional[int],
attention_chunk_size: Optional[int],
):
if hybrid_kvcache_ratio is None:
return None
elif (
hybrid_kvcache_ratio > 0
and model_architectures[0] == "Llama4ForConditionalGeneration"
and context_length > attention_chunk_size
):
return hybrid_kvcache_ratio
else:
return None
def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
if "Llama4ForConditionalGeneration" in model_architectures:
swa_attention_layer_ids = [
i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
]
full_attention_layer_ids = [
i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
]
else:
raise ValueError(
"get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
)
return swa_attention_layer_ids, full_attention_layer_ids
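For intuition, here is a small sketch of the split that `get_hybrid_layer_ids` produces, using a hypothetical 8-layer Llama4-style stack (the real Scout/Maverick models have more layers): every fourth layer keeps a full-attention KV cache, the rest use the sliding-window (SWA) pool.

```python
# Hypothetical illustration of get_hybrid_layer_ids for num_hidden_layers = 8.
num_hidden_layers = 8
swa_attention_layer_ids = [i for i in range(num_hidden_layers) if (i + 1) % 4 != 0]
full_attention_layer_ids = [i for i in range(num_hidden_layers) if (i + 1) % 4 == 0]
print(swa_attention_layer_ids)   # [0, 1, 2, 4, 5, 6] -> chunked / sliding-window attention
print(full_attention_layer_ids)  # [3, 7]             -> full attention
```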
......@@ -433,9 +433,7 @@ class DecodePreallocQueue:
else 0
)
available_size = self.token_to_kv_pool_allocator.available_size()
allocatable_tokens = available_size - max(
allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
# preserve some space for future decode
self.num_reserved_decode_tokens
* (
......
......@@ -9,6 +9,7 @@ import torch
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import SWAKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
......@@ -320,6 +321,11 @@ class FlashAttentionBackend(AttentionBackend):
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
self.skip_prefill = skip_prefill
self.is_hybrid = model_runner.is_hybrid
if self.is_hybrid:
self.full_to_swa_index_mapping = (
model_runner.token_to_kv_pool.full_to_swa_index_mapping
)
self.topk = model_runner.server_args.speculative_eagle_topk or 0
self.speculative_num_steps = speculative_num_steps
self.speculative_num_draft_tokens = (
......@@ -428,7 +434,7 @@ class FlashAttentionBackend(AttentionBackend):
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(metadata, device)
self._init_local_attn_metadata(forward_batch, metadata, device)
elif forward_batch.forward_mode.is_target_verify():
if self.topk <= 1:
metadata.cache_seqlens_int32 = (
......@@ -456,7 +462,7 @@ class FlashAttentionBackend(AttentionBackend):
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
self._init_local_attn_metadata(metadata, device)
self._init_local_attn_metadata(forward_batch, metadata, device)
else:
metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
......@@ -575,7 +581,7 @@ class FlashAttentionBackend(AttentionBackend):
# Setup local attention if enabled
if forward_batch.forward_mode == ForwardMode.EXTEND:
self._init_local_attn_metadata(metadata, device)
self._init_local_attn_metadata(forward_batch, metadata, device)
# Encoder metadata for cross attention
if forward_batch.encoder_lens is not None:
......@@ -1588,7 +1594,7 @@ class FlashAttentionBackend(AttentionBackend):
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
out_cache_loc: torch.Tensor = None,
out_cache_loc: Optional[torch.Tensor] = None,
):
"""Initialize forward metadata for replaying CUDA graph."""
seq_lens = seq_lens[:bs]
......@@ -1673,7 +1679,10 @@ class FlashAttentionBackend(AttentionBackend):
self.page_size,
)
self._update_local_attn_metadata_for_replay(metadata, bs)
self._update_local_attn_metadata_for_replay(
metadata,
bs,
)
elif forward_mode.is_target_verify():
if self.topk <= 1:
metadata = self.target_verify_metadata[bs]
......@@ -1829,7 +1838,9 @@ class FlashAttentionBackend(AttentionBackend):
"""Get the fill value for sequence length in CUDA graph."""
return 1
def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
def _init_local_attn_metadata(
self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device
):
"""Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
if self.attention_chunk_size is None:
metadata.local_attn_metadata = None
......@@ -1837,6 +1848,11 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens_int32 = metadata.cache_seqlens_int32
if self.is_hybrid:
page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
torch.int32
)
else:
page_table = metadata.page_table
if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
metadata.local_attn_metadata = None
......@@ -1923,7 +1939,9 @@ class FlashAttentionBackend(AttentionBackend):
)
def _update_local_attn_metadata_for_replay(
self, metadata: FlashAttentionMetadata, bs: int
self,
metadata: FlashAttentionMetadata,
bs: int,
):
"""Update preallocated local attention metadata in-place before CUDA graph replay."""
if self.attention_chunk_size is None:
......@@ -1954,6 +1972,11 @@ class FlashAttentionBackend(AttentionBackend):
# Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
# beyond the actual sequence length, leading to incorrect attention calculations
max_seq_len = int(seqlens.max().item())
if self.is_hybrid:
sliced_page_table = self.full_to_swa_index_mapping[
metadata.page_table[:bs, :max_seq_len]
].to(torch.int32)
else:
sliced_page_table = metadata.page_table[:bs, :max_seq_len]
cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
......
......@@ -56,7 +56,7 @@ from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
......@@ -485,6 +485,9 @@ class Req:
# for cross-encoder model
self.token_type_ids = token_type_ids
# The number of KV entries that have been evicted in local attention chunked prefill
self.evicted_seqlen_local = 0
# Sampling info
if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params)
......@@ -1191,6 +1194,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.req_to_token_pool.write(
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
if isinstance(self.tree_cache, SWAChunkCache):
self.tree_cache.evict(
req, pre_len, self.model_config.attention_chunk_size
)
# If input_embeds are available, store them
if req.input_embeds is not None:
......@@ -1383,7 +1390,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
* buf_multiplier
* self.token_to_kv_pool_allocator.page_size
)
if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
return True
......@@ -1564,6 +1570,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens.add_(1)
self.seq_lens_sum += bs
# free memory
if isinstance(self.tree_cache, SWAChunkCache):
for req in self.reqs:
self.tree_cache.evict(
req, req.seqlen - 1, self.model_config.attention_chunk_size
)
# Allocate memory
if self.token_to_kv_pool_allocator.page_size == 1:
self.out_cache_loc = self.alloc_token_slots(bs)
......@@ -1798,7 +1811,6 @@ class ModelWorkerBatch:
seq_lens: torch.Tensor
# The indices of output tokens in the token_to_kv_pool_allocator
out_cache_loc: torch.Tensor
# The sequence length tensor on CPU
seq_lens_cpu: Optional[torch.Tensor]
seq_lens_sum: int
......
......@@ -126,7 +126,8 @@ from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
......@@ -570,7 +571,11 @@ class Scheduler(
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
if self.model_config.is_hybrid:
ChunkCacheClass = SWAChunkCache
else:
ChunkCacheClass = ChunkCache
self.tree_cache = ChunkCacheClass(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
......@@ -1283,9 +1288,8 @@ class Scheduler(
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
self.tree_cache.evictable_size()
)
num_new_seq = len(can_run_list)
......@@ -1294,7 +1298,7 @@ class Scheduler(
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"{usage_msg}"
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
......@@ -1337,9 +1341,8 @@ class Scheduler(
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
self.tree_cache.evictable_size()
)
if RECORD_STEP_TIME:
......@@ -1347,12 +1350,7 @@ class Scheduler(
gap_latency / self.server_args.decode_log_interval
)
msg = (
f"Decode batch. "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
)
msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
......@@ -1390,10 +1388,11 @@ class Scheduler(
self._publish_kv_events()
def check_memory(self):
available_size = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
available_token_size = self.token_to_kv_pool_allocator.full_available_size()
else:
available_token_size = self.token_to_kv_pool_allocator.available_size()
available_size = available_token_size + self.tree_cache.evictable_size()
protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (
self.max_total_num_tokens
......@@ -1404,7 +1403,7 @@ class Scheduler(
msg = (
"token_to_kv_pool_allocator memory leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{available_token_size=}\n"
f"{self.tree_cache.evictable_size()=}\n"
)
raise ValueError(msg)
......
......@@ -20,12 +20,14 @@ Page-aligned memory pool.
"""
import abc
import weakref
from typing import TYPE_CHECKING
import torch
import triton
import triton.language as tl
from sglang.srt.mem_cache.memory_pool import SWAKVPool
from sglang.srt.utils import get_bool_env_var, next_power_of_2
if TYPE_CHECKING:
......@@ -55,6 +57,11 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
def debug_print(self) -> str:
return ""
def log_usage(self, evictable_size: int = 0):
num_used = self.size - (self.available_size() + evictable_size)
msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
return msg, num_used
def available_size(self):
return len(self.free_pages) * self.page_size
......@@ -146,6 +153,128 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
"""Allocator for SWA hybrid KV cache."""
def __init__(
self,
size: int,
size_swa: int,
dtype: torch.dtype,
device: str,
kvcache: SWAKVPool,
):
super().__init__(size, 1, dtype, device, kvcache)
assert isinstance(kvcache, SWAKVPool)
self._size_full = size
self._size_swa = size_swa
self.full_attn_allocator = TokenToKVPoolAllocator(
size,
dtype,
device,
kvcache.full_kv_pool,
)
self.swa_attn_allocator = TokenToKVPoolAllocator(
size_swa,
dtype,
device,
kvcache.swa_kv_pool,
)
self.full_to_swa_index_mapping = torch.empty(
size + size_swa + 1,
dtype=torch.int64,
device=device,
)
self.clear()
self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping
def available_size(self):
return min(self.full_available_size(), self.swa_available_size())
def full_available_size(self):
return self.full_attn_allocator.available_size()
def swa_available_size(self):
return self.swa_attn_allocator.available_size()
@property
def size_full(self):
return self._size_full
@property
def size_swa(self):
return self._size_swa
def debug_print(self) -> str:
msg = ""
msg += f"#swa-available-size: {self.swa_attn_allocator.available_size()}, "
msg += (
f"#full-attn-available-size: {self.full_attn_allocator.available_size()}, "
)
return msg
def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
used_full = self.size_full - (self.full_available_size() + full_evictable_size)
used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
msg = (
f"#token: full={used_full}, swa={used_swa}, "
f"token usage: full={used_full / self.size_full:.2f}, "
f"swa={used_swa / self.size_swa:.2f}, "
)
return msg, used_full
def get_kvcache(self):
return self._kvcache
def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
assert self.full_to_swa_index_mapping is not None
return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
def alloc(self, need_size: int):
if need_size > self.full_attn_allocator.available_size():
return None
if need_size > self.swa_attn_allocator.available_size():
return None
alloc_full_indices = self.full_attn_allocator.alloc(need_size)
alloc_swa_indices = self.swa_attn_allocator.alloc(need_size)
self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices
return alloc_full_indices
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
self.full_attn_allocator.free(free_index)
self.free_swa(free_index)
else:
self.free_group.append(free_index)
assert (
self.full_attn_allocator.available_size() <= self.full_attn_allocator.size
)
assert self.swa_attn_allocator.available_size() <= self.swa_attn_allocator.size
def free_swa(self, free_index: torch.Tensor):
swa_indices = self.full_to_swa_index_mapping[free_index]
swa_indices = swa_indices[swa_indices > 0]
self.swa_attn_allocator.free(swa_indices)
self.full_to_swa_index_mapping[free_index] = 0
def backup_state(self):
raise NotImplementedError
def restore_state(self, state):
raise NotImplementedError
def clear(self):
self.swa_attn_allocator.clear()
self.full_attn_allocator.clear()
self.full_to_swa_index_mapping.fill_(0)
self.is_in_free_group = False
self.free_group = []
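A toy sketch (not the real allocator, which wraps two `TokenToKVPoolAllocator` instances) of how `full_to_swa_index_mapping` ties the two pools together: `alloc()` records, for each full-attention slot, which SWA slot holds the same token; `translate_loc_from_full_to_swa` looks the pairing up; and `free_swa()` releases only the SWA side and clears the mapping. Sizes and indices below are made up for illustration.

```python
import torch

# Toy tensors only; sized as in the allocator (size + size_swa + 1), value 0 = "no SWA slot mapped".
size_full, size_swa = 16, 8
full_to_swa = torch.zeros(size_full + size_swa + 1, dtype=torch.int64)

full_idx = torch.tensor([3, 4, 5])  # slots returned by the full-attention allocator
swa_idx = torch.tensor([1, 2, 3])   # slots returned by the SWA allocator for the same tokens
full_to_swa[full_idx] = swa_idx     # what alloc() does after both sub-allocations succeed

# translate_loc_from_full_to_swa: look up SWA locations for full-pool indices.
print(full_to_swa[full_idx].to(torch.int32))  # tensor([1, 2, 3], dtype=torch.int32)

# free_swa(): release the SWA slots early (e.g. tokens that slid out of the window)
# while the full-attention copies stay resident.
full_to_swa[full_idx] = 0
```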
@triton.jit
def alloc_extend_kernel(
pre_lens_ptr,
......
......@@ -2,11 +2,14 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
import torch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
......@@ -63,3 +66,32 @@ class ChunkCache(BasePrefixCache):
def pretty_print(self):
return ""
class SWAChunkCache(ChunkCache):
"""ChunkCache with support for hybrid KV cache operations."""
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: SWATokenToKVPoolAllocator,
page_size: int,
):
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
def evict(
self,
req: Req,
prelen: int,
attention_chunk_size: int,
):
if prelen >= req.evicted_seqlen_local + attention_chunk_size:
new_evicted_seqlen_local = attention_chunk_size * (
prelen // attention_chunk_size
)
free_slots = self.req_to_token_pool.req_to_token[
req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local
]
self.token_to_kv_pool_allocator.free_swa(free_slots)
req.evicted_seqlen_local = new_evicted_seqlen_local
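A quick arithmetic sketch of the eviction boundary in `SWAChunkCache.evict` (values are illustrative): only KV entries older than the start of the current attention chunk are released, and only from the SWA pool; the full-attention copies are kept.

```python
# Illustrative numbers, not taken from a real run.
attention_chunk_size = 8192
evicted_seqlen_local = 0      # nothing evicted yet for this request
prelen = 20000                # tokens already cached for this request

if prelen >= evicted_seqlen_local + attention_chunk_size:
    new_evicted_seqlen_local = attention_chunk_size * (prelen // attention_chunk_size)
    # Positions [evicted_seqlen_local, new_evicted_seqlen_local) = [0, 16384)
    # are freed from the SWA pool via free_swa().
    print(evicted_seqlen_local, new_evicted_seqlen_local)  # 0 16384
```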
......@@ -27,10 +27,11 @@ KVCache actually holds the physical kv cache.
import abc
import logging
from contextlib import nullcontext
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import triton
import triton.language as tl
......@@ -66,6 +67,7 @@ class ReqToTokenPool:
self.req_to_token = torch.zeros(
(size, max_context_len), dtype=torch.int32, device=device
)
self.free_slots = list(range(size))
def write(self, indices, values):
......@@ -191,7 +193,6 @@ class MHATokenToKVPool(KVCache):
start_layer,
end_layer,
)
self.head_num = head_num
self.head_dim = head_dim
......@@ -392,9 +393,13 @@ class MHATokenToKVPool(KVCache):
cache_v: torch.Tensor,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
layer_id_override: Optional[int] = None,
):
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
if layer_id_override is not None:
layer_id = layer_id_override
else:
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
if k_scale is not None:
......@@ -431,6 +436,136 @@ class MHATokenToKVPool(KVCache):
)
class SWAKVPool(KVCache):
"""KV cache with separate pools for full and SWA attention layers."""
def __init__(
self,
size: int,
size_swa: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
swa_attention_layer_ids: List[int],
full_attention_layer_ids: List[int],
enable_kvcache_transpose: bool,
device: str,
):
self.size = size
self.size_swa = size_swa
self.dtype = dtype
self.device = device
self.swa_layer_nums = len(swa_attention_layer_ids)
self.full_layer_nums = len(full_attention_layer_ids)
self.page_size = 1
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose
TokenToKVPoolClass = MHATokenToKVPool
self.swa_kv_pool = TokenToKVPoolClass(
size=size_swa,
page_size=self.page_size,
dtype=dtype,
head_num=head_num,
head_dim=head_dim,
layer_num=self.swa_layer_nums,
device=device,
enable_memory_saver=False,
)
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
head_num=head_num,
head_dim=head_dim,
layer_num=self.full_layer_nums,
device=device,
enable_memory_saver=False,
)
self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
self.layers_mapping[global_layer_id] = (swa_layer_id, True)
self.full_to_swa_index_mapping: Optional[torch.Tensor] = None
def get_kv_size_bytes(self):
raise NotImplementedError
def get_contiguous_buf_infos(self):
full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
self.full_kv_pool.get_contiguous_buf_infos()
)
swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
self.swa_kv_pool.get_contiguous_buf_infos()
)
kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
kv_data_lens = full_kv_data_lens + swa_kv_data_lens
kv_item_lens = full_kv_item_lens + swa_kv_item_lens
return kv_data_ptrs, kv_data_lens, kv_item_lens
def get_key_buffer(self, layer_id: int):
layer_id_pool, is_swa = self.layers_mapping[layer_id]
if is_swa:
return self.swa_kv_pool.get_key_buffer(layer_id_pool)
else:
return self.full_kv_pool.get_key_buffer(layer_id_pool)
def get_value_buffer(self, layer_id: int):
layer_id_pool, is_swa = self.layers_mapping[layer_id]
if is_swa:
return self.swa_kv_pool.get_value_buffer(layer_id_pool)
else:
return self.full_kv_pool.get_value_buffer(layer_id_pool)
def get_kv_buffer(self, layer_id: int):
layer_id_pool, is_swa = self.layers_mapping[layer_id]
if is_swa:
return self.swa_kv_pool.get_kv_buffer(layer_id_pool)
else:
return self.full_kv_pool.get_kv_buffer(layer_id_pool)
def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
assert self.full_to_swa_index_mapping is not None
return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
def set_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
):
layer_id = layer.layer_id
layer_id_pool, is_swa = self.layers_mapping[layer_id]
if is_swa:
if self.full_to_swa_index_mapping is not None:
loc = self.translate_loc_from_full_to_swa(loc)
self.swa_kv_pool.set_kv_buffer(
None,
loc,
cache_k,
cache_v,
k_scale,
v_scale,
layer_id_override=layer_id_pool,
)
else:
self.full_kv_pool.set_kv_buffer(
None,
loc,
cache_k,
cache_v,
k_scale,
v_scale,
layer_id_override=layer_id_pool,
)
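For intuition, here is how `layers_mapping` routes a global layer id to the right sub-pool for the hypothetical 8-layer split shown earlier; the same dictionary drives `get_key_buffer`, `get_value_buffer`, and `set_kv_buffer`.

```python
# Hypothetical 8-layer split (see the get_hybrid_layer_ids sketch above).
swa_attention_layer_ids = [0, 1, 2, 4, 5, 6]
full_attention_layer_ids = [3, 7]

layers_mapping = {}
for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
    layers_mapping[global_layer_id] = (full_attn_layer_id, False)
for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
    layers_mapping[global_layer_id] = (swa_layer_id, True)

print(layers_mapping[3])  # (0, False): global layer 3 uses slot 0 of the full-attention pool
print(layers_mapping[4])  # (3, True):  global layer 4 uses slot 3 of the SWA pool
```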
@triton.jit
def set_mla_kv_buffer_kernel(
kv_buffer_ptr,
......
......@@ -74,6 +74,7 @@ from sglang.srt.managers.schedule_batch import (
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
PagedTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.memory_pool import (
......@@ -81,6 +82,7 @@ from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
SWAKVPool,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
......@@ -185,6 +187,7 @@ class ModelRunner:
self.page_size = server_args.page_size
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.is_hybrid = model_config.is_hybrid
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
self.attention_chunk_size = model_config.attention_chunk_size
......@@ -437,6 +440,10 @@ class ModelRunner:
if self.model_config.context_len > 8192:
self.mem_fraction_static *= 0.85
if self.is_hybrid and not server_args.disable_radix_cache:
logger.info("Automatically disable radix cache for hybrid cache.")
server_args.disable_radix_cache = True
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
......@@ -852,6 +859,40 @@ class ModelRunner:
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
def set_num_token_hybrid(self):
if (
"Llama4ForConditionalGeneration"
in self.model_config.hf_config.architectures
):
temp_ratio = (
(1 - self.is_hybrid)
+ self.is_hybrid
* self.attention_chunk_size
/ self.model_config.context_len
)
self.swa_max_total_num_tokens = (
4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
)
self.full_max_total_num_tokens = (
4 * self.max_total_num_tokens
- 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
)
self.swa_max_total_num_tokens = int(
self.swa_max_total_num_tokens
// self.server_args.page_size
* self.server_args.page_size
)
self.full_max_total_num_tokens = int(
self.full_max_total_num_tokens
// self.server_args.page_size
* self.server_args.page_size
)
self.max_total_num_tokens = self.full_max_total_num_tokens
else:
raise ValueError(
f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}."
)
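The arithmetic in `set_num_token_hybrid` can be read as follows (a sketch with illustrative numbers; the real code additionally floors both sizes to a multiple of `--page-size`): with `r = hybrid_kvcache_ratio`, the target per-layer ratio is `temp_ratio = (1 - r) + r * attention_chunk_size / context_len`, and the split keeps the total slot budget of a 4-layer Llama4 group (3 SWA layers + 1 full-attention layer) equal to what the uniform cache would have used, i.e. `3 * swa + full ≈ 4 * max_total_num_tokens` with `swa ≈ temp_ratio * full`.

```python
# Illustrative numbers only (real values depend on GPU memory and server flags).
max_total_num_tokens = 400_000        # budget a uniform cache over all layers would get
hybrid_kvcache_ratio = 1.0            # r
attention_chunk_size = 8192
context_len = 1_000_000

temp_ratio = (1 - hybrid_kvcache_ratio) + hybrid_kvcache_ratio * attention_chunk_size / context_len
swa = 4 * max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
full = 4 * max_total_num_tokens - 12 * max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)

print(int(swa), int(full))   # 12792 1561622
print(swa / full)            # ~= temp_ratio (flooring causes a tiny offset)
print(int(3 * swa + full))   # 1599998 ~= 4 * max_total_num_tokens
```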
def init_memory_pool(
self,
total_gpu_memory: int,
......@@ -929,6 +970,10 @@ class ModelRunner:
* self.server_args.page_size
)
# Determine the token pool sizes for the hybrid KV cache
if self.is_hybrid:
self.set_num_token_hybrid()
if self.max_total_num_tokens <= 0:
raise RuntimeError(
"Not enough memory. Please try to increase --mem-fraction-static."
......@@ -990,12 +1035,29 @@ class ModelRunner:
start_layer=self.start_layer,
end_layer=self.end_layer,
)
else:
if self.is_hybrid:
self.token_to_kv_pool = SWAKVPool(
size=self.full_max_total_num_tokens,
size_swa=self.swa_max_total_num_tokens,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(
get_attention_tp_size()
),
head_dim=self.model_config.head_dim,
swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
full_attention_layer_ids=self.model_config.full_attention_layer_ids,
enable_kvcache_transpose=False,
device=self.device,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_num=self.model_config.get_num_kv_heads(
get_attention_tp_size()
),
head_dim=self.model_config.head_dim,
layer_num=self.num_effective_layers,
device=self.device,
......@@ -1006,6 +1068,15 @@ class ModelRunner:
if self.token_to_kv_pool_allocator is None:
if self.page_size == 1:
if self.is_hybrid:
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
self.full_max_total_num_tokens,
self.swa_max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
)
else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
......
......@@ -61,6 +61,7 @@ class ServerArgs:
is_embedding: bool = False
enable_multimodal: Optional[bool] = None
revision: Optional[str] = None
hybrid_kvcache_ratio: Optional[float] = None
impl: str = "auto"
# Port for the HTTP server
......@@ -817,6 +818,18 @@ class ServerArgs:
default=ServerArgs.page_size,
help="The number of tokens in a page.",
)
parser.add_argument(
"--hybrid-kvcache-ratio",
nargs="?",
const=0.5,
type=float,
default=ServerArgs.hybrid_kvcache_ratio,
help=(
"Mix ratio in [0,1] between uniform and hybrid kv buffers "
"(0.0 = pure uniform: swa_size / full_size = 1)"
"(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
),
)
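A small sketch of what the flag value means for the target `swa_size / full_size` ratio, assuming a hypothetical `attention_chunk_size` of 8192 and `--context-length 1000000` (the actual pool sizes are then derived in `ModelRunner.set_num_token_hybrid`):

```python
# Illustrative: target swa_size / full_size ratio for a few --hybrid-kvcache-ratio values.
attention_chunk_size, context_length = 8192, 1_000_000
for r in (0.0, 0.5, 1.0):
    ratio = (1 - r) + r * attention_chunk_size / context_length
    print(r, round(ratio, 6))
# 0.0 -> 1.0       pure uniform: SWA pool as large as the full-attention pool
# 0.5 -> 0.504096  halfway between the two extremes
# 1.0 -> 0.008192  pure hybrid: SWA pool sized to the local attention window
```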
# Other runtime options
parser.add_argument(
......