Commit b3230e1a authored by Yongye Zhu's avatar Yongye Zhu Committed by simon-mo
Browse files
parent 03df0fb5
......@@ -102,6 +102,7 @@ class PallasAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
padded_head_size = cdiv(
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
......
......@@ -360,6 +360,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
......
......@@ -68,6 +68,7 @@ class TreeAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
......
......@@ -171,6 +171,7 @@ class TritonAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
......
......@@ -106,6 +106,7 @@ class XFormersAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
......
......@@ -1103,7 +1103,9 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
if is_kv_cache_spec_uniform(kv_cache_spec):
if is_kv_cache_spec_uniform(
kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type(
kv_cache_spec):
return
logger.warning(
......@@ -1128,7 +1130,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
)
elif isinstance(spec, ChunkedLocalAttentionSpec):
......@@ -1137,11 +1138,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
attention_chunk_size=spec.attention_chunk_size,
)
if not is_kv_cache_spec_uniform(kv_cache_spec):
if not (is_kv_cache_spec_uniform(kv_cache_spec)
or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type.")
......
......@@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
CrossAttentionSpec, FullAttentionSpec,
KVCacheSpec, MambaSpec,
SlidingWindowSpec)
MLAAttentionSpec, SlidingWindowSpec)
from vllm.v1.request import Request
......@@ -656,6 +656,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
......
......@@ -59,13 +59,10 @@ class AttentionSpec(KVCacheSpec):
num_kv_heads: int
head_size: int
dtype: torch.dtype
use_mla: bool
@property
def page_size_bytes(self) -> int:
# For MLA we only store a single latent vector
coef = 1 if self.use_mla else 2
return coef * self.block_size * self.num_kv_heads * self.head_size \
return 2 * self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype)
......@@ -118,12 +115,13 @@ class FullAttentionSpec(AttentionSpec):
if spec.sliding_window is not None)
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
if spec.attention_chunk_size is not None)
assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
"MLAAttentionSpec should be merged in MLAAttentionSpec.merge")
merged_spec = cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
dtype=specs[0].dtype,
use_mla=specs[0].use_mla,
sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
)
......@@ -140,6 +138,38 @@ class FullAttentionSpec(AttentionSpec):
return merged_spec
@dataclass(frozen=True)
class MLAAttentionSpec(FullAttentionSpec):
# TODO(Lucas/Chen): less hacky way to do this
cache_dtype_str: Optional[str] = None
@property
def page_size_bytes(self) -> int:
if self.cache_dtype_str == "fp8_ds_mla":
# See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
# for details.
return self.block_size * 656
return self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype)
@classmethod
def merge(cls, specs: list[Self]) -> Self:
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
"All attention layers in the same KV cache group must be "
"MLAAttentionSpec.")
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
assert len(cache_dtype_str_set) == 1, (
"All attention layers in the same KV cache group must use the same "
"quantization method.")
return cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
dtype=specs[0].dtype,
cache_dtype_str=cache_dtype_str_set.pop(),
)
@dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
attention_chunk_size: int
......@@ -163,9 +193,6 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
class SlidingWindowSpec(AttentionSpec):
sliding_window: int
def __post_init__(self):
assert not self.use_mla, "MLA is not supported for sliding window"
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
"DCP not support sliding window."
......@@ -266,9 +293,13 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
# Different block sizes, not uniform.
return False
one_spec = next(iter(kv_cache_specs.values()))
if isinstance(one_spec, (FullAttentionSpec, CrossAttentionSpec)):
if isinstance(one_spec, FullAttentionSpec):
return all(
isinstance(spec, FullAttentionSpec)
for spec in kv_cache_specs.values())
elif isinstance(one_spec, CrossAttentionSpec):
return all(
isinstance(spec, type(one_spec))
isinstance(spec, CrossAttentionSpec)
for spec in kv_cache_specs.values())
elif isinstance(one_spec, SlidingWindowSpec):
return all(
......
......@@ -17,6 +17,7 @@ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
......@@ -31,6 +32,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
......@@ -51,6 +53,7 @@ class EagleProposer:
self.method = self.speculative_config.method
self.runner = runner
self.device = device
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
......@@ -178,20 +181,30 @@ class EagleProposer:
assert self.runner is not None
# Select the correct attention metadata builders for EAGLE layers.
# Get the attention metadata builders once and reuse for later.
builder = (self._get_attention_metadata_builder()
if self.attn_metadata_builder is None else
self.attn_metadata_builder)
attn_metadata = builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata,
draft_index=0)
# FIXME: need to consider multiple kv_cache_groups
ubatch_id = dbo_current_ubatch_id()
attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0)
# FIXME: support hybrid kv for draft model (remove separate indexer)
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = (
self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=0,
))
else:
draft_indexer_metadata = None
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
for layer_name in self.indexer_layer_names:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
......@@ -323,7 +336,7 @@ class EagleProposer:
exceeds_max_model_len, PADDING_SLOT_ID)
# Rebuild attention metadata
attn_metadata = builder.build_for_drafting( # type: ignore
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
for layer_name in self.attn_layer_names:
......@@ -794,6 +807,10 @@ class EagleProposer:
self.vllm_config.speculative_config.draft_model_config
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
# FIXME: support hybrid kv for draft model
target_indexer_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache).keys())
from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"):
......@@ -803,8 +820,25 @@ class EagleProposer:
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names)
indexer_layers = get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache)
draft_indexer_layer_names = (indexer_layers.keys() -
target_indexer_layer_names)
self.attn_layer_names = list(draft_attn_layer_names)
self.indexer_layer_names = list(draft_indexer_layer_names)
if self.indexer_layer_names:
first_layer = self.indexer_layer_names[0]
self.draft_indexer_metadata_builder = (
indexer_layers[first_layer].get_attn_backend().get_builder_cls(
)(
indexer_layers[first_layer].get_kv_cache_spec(),
self.indexer_layer_names,
self.vllm_config,
self.device,
))
else:
self.draft_indexer_metadata_builder = None
if supports_multimodal(target_model):
# handle multimodality
......
......@@ -40,6 +40,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
......@@ -80,7 +81,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec,
MambaSpec, MLAAttentionSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
......@@ -3068,7 +3070,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_i = (attn_group\
.get_metadata_builder(ubatch_id=ubid)\
.build_for_cudagraph_capture(common_attn_metadata))
for layer_name in kv_cache_group_spec.layer_names:
for layer_name in attn_group.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][
layer_name] = attn_metadata_i
......@@ -3076,7 +3078,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert type(attn_metadata) is dict
attn_metadata_i = attn_group.get_metadata_builder()\
.build_for_cudagraph_capture(common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(self.lora_config,
......@@ -3823,8 +3825,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = \
......@@ -4010,7 +4015,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Add encoder-only layers to the KV cache config.
"""
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
encoder_only_attn_specs: dict[AttentionSpec,
list[str]] = defaultdict(list)
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
......@@ -4020,8 +4024,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
dtype=self.kv_cache_dtype)
encoder_only_attn_specs[attn_spec].append(layer_name)
self.runner_only_attn_layers.add(layer_name)
if len(encoder_only_attn_specs) > 0:
......@@ -4043,6 +4046,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
......@@ -4062,13 +4066,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
if attn_module.sliding_window is not None:
assert not use_mla, "MLA is not supported for sliding" \
"window"
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla)
sliding_window=attn_module.sliding_window)
elif use_mla:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=cache_dtype_str)
elif self.attention_chunk_size is not None \
and isinstance(attn_module, ChunkedLocalAttention):
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
......@@ -4076,22 +4088,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
attention_chunk_size=self.attention_chunk_size,
use_mla=use_mla)
attention_chunk_size=self.attention_chunk_size)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
dtype=self.kv_cache_dtype)
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
kv_cache_spec[layer_name] = CrossAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
dtype=self.kv_cache_dtype)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
......@@ -4128,6 +4137,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0),
)
ds_indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache)
for layer_name, ds_indexer_module in ds_indexer_layers.items():
kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec()
return kv_cache_spec
......
......@@ -530,7 +530,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=False,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
......@@ -538,7 +537,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=False,
)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment