Commit 5206e5e2 (unverified), authored by Harry Huang, committed by GitHub
Browse files

[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)


Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
parent fec9da0a
......@@ -35,6 +35,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01LinearAttention
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -1006,3 +1008,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
tp_size=parallel_config.tensor_parallel_size,
head_dim=hf_config.head_dim,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
    """Return the Mamba state copy function(s) for this model class.

    Delegates to
    ``MambaStateCopyFuncCalculator.linear_attention_state_copy_func()``.

    Returns:
        Whatever the calculator returns; the annotation declares a tuple
        holding a single ``MambaStateCopyFunc``.
    """
    copy_funcs = MambaStateCopyFuncCalculator.linear_attention_state_copy_func()
    return copy_funcs
......@@ -2128,3 +2128,7 @@ class NemotronH_Nano_VL_V2(
temp_vllm_config = copy.deepcopy(vllm_config)
temp_vllm_config.model_config.hf_config = text_config
return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config)
@classmethod
def get_mamba_state_copy_func(cls):
    """Return the Mamba state copy functions for this model class.

    This class does not compute them itself; it forwards to
    ``NemotronHForCausalLM.get_mamba_state_copy_func()`` and returns that
    result unchanged.
    """
    text_model_cls = NemotronHForCausalLM
    return text_model_cls.get_mamba_state_copy_func()
......@@ -45,6 +45,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -809,6 +811,10 @@ class NemotronHForCausalLM(
conv_kernel=hf_config.conv_kernel,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the Mamba state copy functions for this model class.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func()``.

    Returns:
        Whatever the calculator returns; the annotation declares a pair of
        ``MambaStateCopyFunc`` entries.
    """
    copy_funcs = MambaStateCopyFuncCalculator.mamba2_state_copy_func()
    return copy_funcs
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
......
......@@ -27,6 +27,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -899,6 +901,10 @@ class Plamo2ForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the Mamba state copy functions for this model class.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func()``.

    Returns:
        Whatever the calculator returns; the annotation declares a pair of
        ``MambaStateCopyFunc`` entries.
    """
    copy_funcs = MambaStateCopyFuncCalculator.mamba2_state_copy_func()
    return copy_funcs
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -48,6 +48,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -1205,9 +1207,11 @@ class Qwen3NextForCausalLM(
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, (
"Qwen3Next currently does not support prefix caching"
)
if cache_config.mamba_cache_mode == "all":
raise NotImplementedError(
"Qwen3Next currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
self.quant_config = vllm_config.quant_config
super().__init__()
......@@ -1278,6 +1282,10 @@ class Qwen3NextForCausalLM(
num_spec,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the Mamba state copy functions for this model class.

    Delegates to
    ``MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()``.

    Returns:
        Whatever the calculator returns; the annotation declares a pair of
        ``MambaStateCopyFunc`` entries.
    """
    copy_funcs = MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()
    return copy_funcs
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -234,9 +234,11 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, (
"Qwen3NextMTP currently does not support prefix caching"
)
if cache_config.mamba_cache_mode == "all":
raise NotImplementedError(
"Qwen3NextMTP currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
self.quant_config = vllm_config.quant_config
......
......@@ -32,6 +32,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
......@@ -891,6 +893,10 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the Mamba state copy functions for this model class.

    Delegates to ``MambaStateCopyFuncCalculator.mamba2_state_copy_func()``.

    Returns:
        Whatever the calculator returns; the annotation declares a pair of
        ``MambaStateCopyFunc`` entries.
    """
    copy_funcs = MambaStateCopyFuncCalculator.mamba2_state_copy_func()
    return copy_funcs
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
"""Initialize the Zamba2 model for causal language modeling.
......
......@@ -16,6 +16,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata,
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
......@@ -158,6 +159,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
query_start_loc_cpu = m.query_start_loc_cpu
context_lens_tensor = m.compute_num_computed_tokens()
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
block_table_tensor = mamba_get_block_table_tensor(
m.block_table_tensor,
m.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)
spec_sequence_masks_cpu: torch.Tensor | None = None
if (
......@@ -189,7 +196,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_token_indx = None
non_spec_token_indx = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
non_spec_state_indices_tensor = block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
non_spec_query_start_loc_cpu = query_start_loc_cpu
......@@ -221,7 +228,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_token_indx = torch.empty(
0, dtype=torch.int32, device=query_start_loc.device
)
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
......@@ -235,10 +242,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]
spec_state_indices_tensor = m.block_table_tensor[
spec_state_indices_tensor = block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = m.block_table_tensor[
non_spec_state_indices_tensor = block_table_tensor[
~spec_sequence_masks, 0
]
......
......@@ -11,7 +11,10 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.attention.backends.utils import (
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
......@@ -61,7 +64,12 @@ class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMet
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
state_indices_tensor = mamba_get_block_table_tensor(
common_attn_metadata.block_table_tensor,
common_attn_metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
......
......@@ -18,6 +18,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata,
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
......@@ -41,11 +42,15 @@ class BaseMambaAttentionMetadata:
state_indices_tensor: torch.Tensor
# The following tensors are only used for prefix caching and are None if disabled
# The following tensors are only used for prefix caching in all mode and
# are None if disabled
block_idx_last_scheduled_token: torch.Tensor | None
block_idx_first_scheduled_token_p: torch.Tensor | None
block_idx_last_computed_token: torch.Tensor | None
# The following tensor is only used for prefix caching in align mode
seq_lens: torch.Tensor
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
......@@ -78,7 +83,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
self.compilation_config.max_cudagraph_capture_size,
)
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
self.state_indices_tensor = torch.empty(
(
self.decode_cudagraph_max_bs,
......@@ -198,7 +203,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
# Return a tensor of shape (#requests, #max blocks)
......@@ -214,7 +219,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
state_indices_tensor = mamba_get_block_table_tensor(
common_attn_metadata.block_table_tensor,
common_attn_metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)[:, 0]
if num_prefills > 0:
if num_computed_tokens is None:
......@@ -239,7 +249,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
)
)
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
......@@ -258,7 +268,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
......@@ -286,6 +296,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
num_reqs=num_reqs,
seq_lens=common_attn_metadata.seq_lens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
......@@ -298,8 +309,16 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
slot_mapping: torch.Tensor,
) -> M:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
state_indices_t = mamba_get_block_table_tensor(
blk_table,
metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)
if self.vllm_config.cache_config.mamba_cache_mode in ("none", "align"):
# Only needs the block that saves the running state
state_indices_t = state_indices_t[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
......
......@@ -17,6 +17,7 @@ from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
......@@ -854,3 +855,40 @@ def extend_all_queries_by_1(
slot_mapping=new_slot_mapping,
)
return new_cad
def mamba_get_block_table_tensor(
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    kv_cache_spec: KVCacheSpec,
    mamba_cache_mode: str,
) -> torch.Tensor:
    """
    Derive the block table that mamba kernels should use from
    ``common_attn_metadata.block_table_tensor`` for a given mamba cache mode.

    Shapes per mode (input -> output):
    - "all":   (#requests, cdiv(max_model_len, block_size)) -> unchanged.
    - "none":  (#requests, 1 + num_speculative_blocks) -> unchanged.
    - "align": (#requests, cdiv(max_model_len, block_size)) ->
               (#requests, 1 + num_speculative_blocks), i.e. the last
               1 + num_speculative_blocks blocks of each request.

    Args:
        block_table: Block table tensor as described above.
        seq_lens: Per-request sequence lengths; used only in the gathered
            branch to locate each request's first column.
        kv_cache_spec: Must be a ``MambaSpec`` whenever the gathered branch
            is taken; it supplies ``block_size`` and
            ``num_speculative_blocks``.
        mamba_cache_mode: "all" and "none" pass the table through; any other
            value takes the gathered ("align") branch.

    Returns:
        ``block_table`` itself for "all"/"none", otherwise a gathered tensor
        with ``1 + num_speculative_blocks`` columns per request.
    """
    if mamba_cache_mode == "all" or mamba_cache_mode == "none":
        # Nothing to select: the kernels consume the table as given.
        return block_table

    assert isinstance(kv_cache_spec, MambaSpec)
    # Column of the block that holds token (seq_len - 1).
    # NOTE: For 0-length requests in CUDA graph, use a start_index of 0
    # to handle the invalid block table.
    first_column = ((seq_lens - 1) // kv_cache_spec.block_size).clamp(min=0)
    column_offsets = torch.arange(
        1 + kv_cache_spec.num_speculative_blocks, device=block_table.device
    )
    # Broadcast (#requests, 1) + (k,) -> (#requests, k) column indices.
    gather_index = first_column.unsqueeze(1) + column_offsets
    return block_table.gather(1, gather_index)
......@@ -255,7 +255,8 @@ class BlockPool:
)
for i, blk in enumerate(new_full_blocks):
# Some blocks may be null blocks when enabling sparse attention like
# sliding window attention. We skip null blocks here.
# sliding window attention, or Mamba models with prefix-caching in
# align mode. We skip null blocks here.
if blk.is_null:
continue
assert blk.block_hash is None
......
......@@ -75,6 +75,7 @@ class KVCacheCoordinator(ABC):
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
num_encoder_tokens: int,
total_computed_tokens: int,
num_tokens_main_model: int,
) -> int:
"""
Get the number of blocks needed to be allocated for the request.
......@@ -88,6 +89,9 @@ class KVCacheCoordinator(ABC):
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
total_computed_tokens: Include both local and external tokens.
num_tokens_main_model: The number of tokens for the main model (aka target
model in spec decode). w/o spec decode, it is num_tokens;
with spec decode, it is num_tokens - num_lookahead_tokens.
Returns:
The number of blocks to allocate.
......@@ -98,7 +102,7 @@ class KVCacheCoordinator(ABC):
# For cross-attention, we issue a single static allocation
# of blocks based on the number of encoder input tokens.
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, [], 0
request_id, num_encoder_tokens, [], 0, num_encoder_tokens
)
else:
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
......@@ -106,6 +110,7 @@ class KVCacheCoordinator(ABC):
num_tokens,
new_computed_blocks[i],
total_computed_tokens,
num_tokens_main_model,
)
return num_blocks_to_allocate
......@@ -139,6 +144,7 @@ class KVCacheCoordinator(ABC):
self,
request_id: str,
num_tokens: int,
num_tokens_main_model: int,
num_encoder_tokens: int = 0,
) -> tuple[list[KVCacheBlock], ...]:
"""
......@@ -149,6 +155,9 @@ class KVCacheCoordinator(ABC):
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
num_tokens_main_model: The number of tokens for the main model (aka target
model in spec decode). w/o spec decode, it is num_tokens;
with spec decode, it is num_tokens - num_lookahead_tokens.
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
......@@ -161,6 +170,7 @@ class KVCacheCoordinator(ABC):
num_encoder_tokens
if isinstance(manager, CrossAttentionManager)
else num_tokens,
num_tokens_main_model,
)
for manager in self.single_type_managers
)
......
......@@ -307,8 +307,9 @@ class KVCacheManager:
num_local_computed_tokens + num_external_computed_tokens,
self.max_model_len,
)
num_tokens_main_model = total_computed_tokens + num_new_tokens
num_tokens_need_slot = min(
total_computed_tokens + num_new_tokens + num_lookahead_tokens,
num_tokens_main_model + num_lookahead_tokens,
self.max_model_len,
)
......@@ -329,6 +330,7 @@ class KVCacheManager:
num_encoder_tokens=num_encoder_tokens,
total_computed_tokens=num_local_computed_tokens
+ num_external_computed_tokens,
num_tokens_main_model=num_tokens_main_model,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
......@@ -349,7 +351,10 @@ class KVCacheManager:
)
new_blocks = self.coordinator.allocate_new_blocks(
request.request_id, num_tokens_need_slot, num_encoder_tokens
request.request_id,
num_tokens_need_slot,
num_tokens_main_model,
num_encoder_tokens,
)
# P/D: delay caching blocks if we have to recv from
......
......@@ -47,7 +47,7 @@ from vllm.v1.core.sched.output import (
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import (
PrefixCacheStats,
......@@ -226,6 +226,17 @@ class Scheduler(SchedulerInterface):
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
    """Report whether any KV cache group in the config uses a ``MambaSpec``."""
    for group in kv_cache_config.kv_cache_groups:
        if isinstance(group.kv_cache_spec, MambaSpec):
            return True
    return False
self.has_mamba_layers = has_mamba_layers(kv_cache_config)
self.need_mamba_block_aligned_split = (
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
)
self.perf_metrics: ModelMetrics | None = None
if self.log_stats and vllm_config.observability_config.enable_mfu_metrics:
self.perf_metrics = ModelMetrics(vllm_config)
......@@ -250,6 +261,53 @@ class Scheduler(SchedulerInterface):
vllm_config=self.vllm_config,
)
def _mamba_block_aligned_split(
    self,
    request: Request,
    num_new_tokens: int,
    num_new_local_computed_tokens: int = 0,
    num_external_computed_tokens: int = 0,
) -> int:
    """Adjust a prefill chunk so that it ends on a Mamba-cacheable boundary.

    Only requests with no output tokens yet (prefill) are adjusted; for any
    other request ``num_new_tokens`` is returned unchanged.

    Args:
        request: The request being scheduled.
        num_new_tokens: Tokens the scheduler currently intends to run.
        num_new_local_computed_tokens: Newly found locally computed tokens
            that are not yet reflected in ``request.num_computed_tokens``.
        num_external_computed_tokens: Externally computed tokens; must be 0.

    Returns:
        The possibly reduced number of new tokens. This can be 0 when the
        chunk is smaller than one block and the last cacheable position is
        still out of reach.
    """
    assert num_external_computed_tokens == 0, (
        "External KV connector is not verified yet"
    )
    # TODO: need check for resume requests
    if request.num_output_tokens != 0:
        # Not a prefill step: leave the chunk untouched.
        return num_new_tokens

    # To enable block-aligned caching of the Mamba state, the chunk must be a
    # multiple of `block_size`. A chunk shorter than `block_size` that
    # reaches the tail of the prompt is simply not cached.
    block_size = self.cache_config.block_size
    prompt_len = request.num_prompt_tokens
    # Largest multiple of block_size that does not exceed the prompt length.
    last_cache_position = prompt_len - prompt_len % block_size
    if self.use_eagle:
        # With Eagle, FullAttn prunes the last matching block. Step back one
        # block so the final chunk is larger than `block_size` and the Mamba
        # cache is not missed because of that pruning.
        last_cache_position = max(last_cache_position - block_size, 0)

    tokens_before = (
        request.num_computed_tokens
        + num_new_local_computed_tokens
        + num_external_computed_tokens
    )
    tokens_after = tokens_before + num_new_tokens

    if tokens_after < last_cache_position:
        # Still short of the last cacheable position: round the chunk down
        # to a whole number of blocks.
        return num_new_tokens // block_size * block_size
    if tokens_before < last_cache_position < tokens_after:
        # The chunk would cross the last cacheable position: stop exactly
        # there so that state gets cached.
        return last_cache_position - tokens_before
    # At or past the last cacheable position: prefill the remaining tokens.
    return num_new_tokens
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
......@@ -340,6 +398,11 @@ class Scheduler(SchedulerInterface):
shift_computed_tokens=1 if self.use_eagle else 0,
)
if self.need_mamba_block_aligned_split:
num_new_tokens = self._mamba_block_aligned_split(
request, num_new_tokens
)
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
......@@ -350,6 +413,8 @@ class Scheduler(SchedulerInterface):
# its max_total_tokens or max_model_len.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# 4. Insufficient budget for a block-aligned chunk in hybrid
# models with mamba cache mode "align".
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
......@@ -608,6 +673,16 @@ class Scheduler(SchedulerInterface):
# The request cannot be scheduled.
break
if self.need_mamba_block_aligned_split:
num_new_tokens = self._mamba_block_aligned_split(
request,
num_new_tokens,
num_new_local_computed_tokens,
num_external_computed_tokens,
)
if num_new_tokens == 0:
break
# Handles an edge case when P/D Disaggregation
# is used with Spec Decoding where an
# extra block gets allocated which
......
......@@ -66,12 +66,17 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block
@classmethod
def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]):
    """Count the blocks that have ``ref_cnt == 0`` and are not null blocks."""
    num_evictable = 0
    for block in blocks:
        if block.ref_cnt == 0 and not block.is_null:
            num_evictable += 1
    return num_evictable
def get_num_blocks_to_allocate(
self,
request_id: str,
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
total_computed_tokens: int,
num_tokens_main_model: int,
) -> int:
"""
Get the number of blocks needed to be allocated for the request.
......@@ -84,6 +89,9 @@ class SingleTypeKVCacheManager(ABC):
prefix caching.
total_computed_tokens: Include both local and external computed
tokens.
num_tokens_main_model: The number of tokens for the main model (aka target
model in spec decode). w/o spec decode, it is num_tokens;
with spec decode, it is num_tokens - num_lookahead_tokens.
Returns:
The number of blocks to allocate.
......@@ -121,9 +129,8 @@ class SingleTypeKVCacheManager(ABC):
# If a computed block is an eviction candidate (in the free queue and
# ref_cnt == 0), it will be removed from the free queue when touched by
# the allocated request, so we must count it in the free-capacity check.
num_evictable_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null
for blk in new_computed_blocks[num_skipped_new_computed_blocks:]
num_evictable_blocks = self._get_num_evictable_blocks(
new_computed_blocks[num_skipped_new_computed_blocks:]
)
return num_new_blocks + num_evictable_blocks
......@@ -199,7 +206,7 @@ class SingleTypeKVCacheManager(ABC):
req_blocks.extend(allocated_blocks)
def allocate_new_blocks(
self, request_id: str, num_tokens: int
self, request_id: str, num_tokens: int, num_tokens_main_model: int
) -> list[KVCacheBlock]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
......@@ -209,7 +216,9 @@ class SingleTypeKVCacheManager(ABC):
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
num_tokens_main_model: The number of tokens for the main model (aka target
model in spec decode). w/o spec decode, it is num_tokens;
with spec decode, it is num_tokens - num_lookahead_tokens.
Returns:
The new allocated blocks.
"""
......@@ -450,12 +459,9 @@ class FullAttentionManager(SingleTypeKVCacheManager):
class SlidingWindowManager(SingleTypeKVCacheManager):
def __init__(
self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs
) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None:
super().__init__(kv_cache_spec, **kwargs)
self.sliding_window = kv_cache_spec.sliding_window
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
......@@ -586,12 +592,9 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
def __init__(
self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs
) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None:
super().__init__(kv_cache_spec, **kwargs)
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
......@@ -739,6 +742,17 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
class MambaManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None:
    """Set up the manager and, in "align" mode, its per-request bookkeeping.

    Args:
        kv_cache_spec: Spec supplying ``mamba_cache_mode`` and
            ``num_speculative_blocks``.
        **kwargs: Forwarded unchanged to the base-class constructor.
    """
    super().__init__(kv_cache_spec, **kwargs)
    self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode
    self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks
    if self.mamba_cache_mode != "align":
        # The attributes below exist only in align mode.
        return

    # Mapping from request ID to the index of the block
    # allocated in the previous step
    self.last_state_block_idx: dict[str, int] = {}
    # The set of the requests that have been allocated blocks
    self._allocated_block_reqs: set[str] = set()
@classmethod
def find_longest_cache_hit(
cls,
......@@ -787,6 +801,28 @@ class MambaManager(SingleTypeKVCacheManager):
return computed_blocks
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
    """Run the base-class removal, then drop the stale state block in align mode.

    Args:
        request_id: The request whose blocks are being trimmed.
        num_computed_tokens: Number of tokens computed so far for it.
    """
    assert isinstance(self.kv_cache_spec, MambaSpec)
    super().remove_skipped_blocks(request_id, num_computed_tokens)
    if self.mamba_cache_mode != "align":
        return

    # `last_state_block_idx` refers to the block index allocated two steps
    # ago. The block allocated in the previous step is the copy source for
    # the block allocated in the current step, so the earlier one is no
    # longer needed and is released here.
    stale_idx = self.last_state_block_idx.get(request_id)
    if stale_idx is None:
        return
    # Only release it if it lies strictly before the block that covers the
    # last computed token.
    if stale_idx >= cdiv(num_computed_tokens, self.block_size) - 1:
        return

    # Blocks allocated during prefill may be non-contiguous, so the recorded
    # index (not a position derived from the token count) picks the block to
    # free; its slot is then filled with the null block.
    req_blocks = self.req_to_blocks[request_id]
    stale_block = req_blocks[stale_idx]
    if stale_block != self._null_block:
        self.block_pool.free_blocks([stale_block])
        req_blocks[stale_idx] = self._null_block
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
cascade attention is not supported by mamba
......@@ -799,31 +835,134 @@ class MambaManager(SingleTypeKVCacheManager):
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
total_computed_tokens: int,
num_tokens_main_model: int,
) -> int:
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
assert isinstance(self.kv_cache_spec, MambaSpec)
if self.kv_cache_spec.num_speculative_blocks > 0:
num_tokens += (
self.kv_cache_spec.block_size
* self.kv_cache_spec.num_speculative_blocks
if self.mamba_cache_mode != "align":
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
if self.num_speculative_blocks > 0:
num_tokens += (
self.kv_cache_spec.block_size * self.num_speculative_blocks
)
return super().get_num_blocks_to_allocate(
request_id,
num_tokens,
new_computed_blocks,
total_computed_tokens,
num_tokens_main_model,
)
return super().get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks, total_computed_tokens
)
else:
# We don't allocate blocks for lookahead tokens in align mode, because if
# x * block_size tokens are scheduled, num_tokens is
# x * block_size + num_lookahead_tokens and breaks the alignment.
# We can ignore lookahead tokens because current draft models don't have
# mamba layers.
num_tokens = num_tokens_main_model
num_required_blocks = (
cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
)
num_new_blocks = (
num_required_blocks
- len(new_computed_blocks)
- len(self.req_to_blocks[request_id])
)
if num_new_blocks > 0:
if request_id in self._allocated_block_reqs:
# Old request. Needs at most 1 more blocks as we can reuse the
# speculative blocks in previous step.
num_new_blocks = 1
else:
# First prefill. Allocate 1 block for running state and the
# speculative blocks.
num_new_blocks = 1 + self.num_speculative_blocks
num_evictable_computed_blocks = self._get_num_evictable_blocks(
new_computed_blocks
)
return num_new_blocks + num_evictable_computed_blocks
def allocate_new_blocks(
self, request_id: str, num_tokens: int
self, request_id: str, num_tokens: int, num_tokens_main_model: int
) -> list[KVCacheBlock]:
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
assert isinstance(self.kv_cache_spec, MambaSpec)
if self.kv_cache_spec.num_speculative_blocks > 0:
num_tokens += (
self.kv_cache_spec.block_size
* self.kv_cache_spec.num_speculative_blocks
if self.mamba_cache_mode != "align":
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
if self.num_speculative_blocks > 0:
num_tokens += self.block_size * self.num_speculative_blocks
return super().allocate_new_blocks(
request_id, num_tokens, num_tokens_main_model
)
return super().allocate_new_blocks(request_id, num_tokens)
else:
# We don't allocate blocks for lookahead tokens in align mode, because if
# x * block_size tokens are scheduled, num_tokens is
# x * block_size + num_lookahead_tokens and breaks the alignment.
# We can ignore lookahead tokens because current draft models don't have
# mamba layers.
num_tokens = num_tokens_main_model
req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id]
num_required_blocks = (
cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
)
if num_required_blocks == len(req_blocks):
return []
else:
assert num_required_blocks > len(req_blocks), (
"num_required_blocks "
f"{num_required_blocks} < len(req_blocks) {len(req_blocks)}"
)
prev_block_len = len(req_blocks)
blocks_allocated = request_id in self._allocated_block_reqs
# Record the last state block
if blocks_allocated:
# We always save the running state at the last
# (1 + num_speculative_blocks) block
self.last_state_block_idx[request_id] = (
prev_block_len - 1 - self.num_speculative_blocks
)
elif prev_block_len > 0:
# When a new request hits the prefix cache, the last block
# saves the hit state.
self.last_state_block_idx[request_id] = prev_block_len - 1
num_skipped_blocks = (
num_required_blocks - self.num_speculative_blocks - 1
)
# null blocks
if prev_block_len < num_skipped_blocks:
req_blocks.extend(
[
self._null_block
for _ in range(prev_block_len, num_skipped_blocks)
]
)
if blocks_allocated:
# reuse previous speculative blocks in this step
for block_idx in range(
prev_block_len - self.num_speculative_blocks, prev_block_len
):
if block_idx < num_skipped_blocks:
req_blocks.append(req_blocks[block_idx])
req_blocks[block_idx] = self._null_block
else:
break
num_new_blocks = num_required_blocks - len(req_blocks)
if blocks_allocated:
assert num_new_blocks <= 1
else:
assert num_new_blocks <= self.num_speculative_blocks + 1
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
self._allocated_block_reqs.add(request_id)
return req_blocks[prev_block_len:]
def free(self, request_id: str) -> None:
    """Drop align-mode bookkeeping for a request, then free it via the parent.

    When ``mamba_cache_mode`` is ``"align"``, the request is removed from
    ``_allocated_block_reqs`` and its ``last_state_block_idx`` entry is
    dropped. Both removals tolerate a request that was never recorded.
    The parent class's ``free`` is always invoked afterwards.
    """
    uses_align_mode = self.mamba_cache_mode == "align"
    if uses_align_mode:
        # discard() / pop(..., None) never raise for an unknown request id.
        self._allocated_block_reqs.discard(request_id)
        self.last_state_block_idx.pop(request_id, None)
    super().free(request_id)
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
"""
......
......@@ -276,6 +276,7 @@ class MambaSpec(KVCacheSpec):
dtypes: tuple[torch.dtype]
page_size_padded: int | None = None
mamba_type: str = "mamba2"
mamba_cache_mode: str = "none"
num_speculative_blocks: int = 0
@property
......@@ -290,8 +291,13 @@ class MambaSpec(KVCacheSpec):
return page_size
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    """Return the maximum memory usage in bytes for the configured cache mode.

    The result depends on ``vllm_config.cache_config.mamba_cache_mode``:

    * ``"all"``: one page per block needed to cover
      ``model_config.max_model_len`` tokens.
    * ``"align"``: ``2 + num_speculative_blocks`` pages.
    * anything else: ``1 + num_speculative_blocks`` pages.

    Args:
        vllm_config: Configuration providing ``cache_config.mamba_cache_mode``
            and, for the ``"all"`` mode, ``model_config.max_model_len``.

    Returns:
        The page count for the active mode multiplied by ``page_size_bytes``.
    """
    # Read the mode once. An unconditional early return placed before this
    # branching would make every mode-specific size below unreachable.
    cache_mode = vllm_config.cache_config.mamba_cache_mode
    if cache_mode == "all":
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
    if cache_mode == "align":
        # NOTE(review): presumably one page for the running state plus one
        # for the last saved state -- confirm against the cache manager.
        return self.page_size_bytes * (2 + self.num_speculative_blocks)
    return self.page_size_bytes * (1 + self.num_speculative_blocks)
@dataclass(frozen=True)
......
......@@ -8,6 +8,7 @@ from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__)
......@@ -261,47 +262,45 @@ class MultiGroupBlockTable:
device: torch.device,
block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
max_num_blocks: list[int] | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
try:
pcp_world_size = get_pcp_group().world_size
except AssertionError:
# PCP might not be initialized in testing
pcp_world_size = 1
try:
dcp_world_size = get_dcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
if len(kernel_block_sizes) != len(block_sizes):
raise ValueError(
f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
f"must match block_sizes length ({len(block_sizes)})"
)
total_cp_world_size = dcp_world_size * pcp_world_size
if max_num_blocks is None:
# Note(hc): each dcp rank only stores
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size used to calculate max_num_blocks_per_req
# must be multiplied by dcp_world_size.
total_cp_world_size = get_total_cp_world_size()
max_num_blocks = [
cdiv(max_model_len, block_size * total_cp_world_size)
for block_size in block_sizes
]
if len(max_num_blocks) != len(block_sizes):
raise ValueError(
f"max_num_blocks length ({len(max_num_blocks)}) "
f"must match block_sizes length ({len(block_sizes)})"
)
self.block_tables = [
BlockTable(
block_size,
max_num_reqs,
max(
cdiv(max_model_len, block_size * total_cp_world_size),
1 + num_speculative_tokens,
),
max_num_blocks_per_req,
max_num_batched_tokens,
pin_memory,
device,
kernel_block_size,
cp_kv_cache_interleave_size,
)
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
for block_size, kernel_block_size, max_num_blocks_per_req in zip(
block_sizes, kernel_block_sizes, max_num_blocks
)
]
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
......
......@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, cast
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed import get_dcp_group, get_pcp_group
if TYPE_CHECKING:
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
......@@ -40,3 +41,17 @@ def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None:
f"but the impl {layer_impl.__class__.__name__} "
"does not support PCP."
)
def get_total_cp_world_size():
    """Return the product of the DCP and PCP group world sizes.

    A group whose getter raises ``AssertionError`` (which may happen in
    tests where the group is not initialized) counts as world size 1.
    """

    def _world_size_or_one(get_group):
        # Only AssertionError is treated as "group not initialized".
        try:
            return get_group().world_size
        except AssertionError:
            return 1

    # Query PCP first, then DCP.
    pcp_size = _world_size_or_one(get_pcp_group)
    dcp_size = _world_size_or_one(get_dcp_group)
    return dcp_size * pcp_size
......@@ -89,11 +89,11 @@ class InputBatch:
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[int],
max_num_blocks_per_req: list[int] | None = None,
logitsprocs: LogitsProcessors | None = None,
logitsprocs_need_output_token_ids: bool = False,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
cp_kv_cache_interleave_size: int = 1,
):
self.is_pooling_model = is_pooling_model
......@@ -146,7 +146,7 @@ class InputBatch:
device=device,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens,
max_num_blocks=max_num_blocks_per_req,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment