Commit b3230e1a authored by Yongye Zhu, committed by simon-mo
Browse files
parent 03df0fb5
...@@ -102,6 +102,7 @@ class PallasAttentionBackend(AttentionBackend): ...@@ -102,6 +102,7 @@ class PallasAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
padded_head_size = cdiv( padded_head_size = cdiv(
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
......
...@@ -360,6 +360,7 @@ class AiterFlashAttentionBackend(AttentionBackend): ...@@ -360,6 +360,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
......
...@@ -68,6 +68,7 @@ class TreeAttentionBackend(AttentionBackend): ...@@ -68,6 +68,7 @@ class TreeAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
......
...@@ -171,6 +171,7 @@ class TritonAttentionBackend(AttentionBackend): ...@@ -171,6 +171,7 @@ class TritonAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
......
...@@ -106,6 +106,7 @@ class XFormersAttentionBackend(AttentionBackend): ...@@ -106,6 +106,7 @@ class XFormersAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
......
...@@ -1103,7 +1103,9 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): ...@@ -1103,7 +1103,9 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
kv_cache_spec: The kv cache spec of each attention layer in the model kv_cache_spec: The kv cache spec of each attention layer in the model
""" """
if is_kv_cache_spec_uniform(kv_cache_spec): if is_kv_cache_spec_uniform(
kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type(
kv_cache_spec):
return return
logger.warning( logger.warning(
...@@ -1128,7 +1130,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): ...@@ -1128,7 +1130,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
num_kv_heads=spec.num_kv_heads, num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size, head_size=spec.head_size,
dtype=spec.dtype, dtype=spec.dtype,
use_mla=spec.use_mla,
sliding_window=spec.sliding_window, sliding_window=spec.sliding_window,
) )
elif isinstance(spec, ChunkedLocalAttentionSpec): elif isinstance(spec, ChunkedLocalAttentionSpec):
...@@ -1137,11 +1138,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): ...@@ -1137,11 +1138,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
num_kv_heads=spec.num_kv_heads, num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size, head_size=spec.head_size,
dtype=spec.dtype, dtype=spec.dtype,
use_mla=spec.use_mla,
attention_chunk_size=spec.attention_chunk_size, attention_chunk_size=spec.attention_chunk_size,
) )
if not is_kv_cache_spec_uniform(kv_cache_spec): if not (is_kv_cache_spec_uniform(kv_cache_spec)
or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)):
raise ValueError("Hybrid KV cache manager is disabled but failed to " raise ValueError("Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type.") "convert the KV cache specs to one unified type.")
......
...@@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock ...@@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
CrossAttentionSpec, FullAttentionSpec, CrossAttentionSpec, FullAttentionSpec,
KVCacheSpec, MambaSpec, KVCacheSpec, MambaSpec,
SlidingWindowSpec) MLAAttentionSpec, SlidingWindowSpec)
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -656,6 +656,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): ...@@ -656,6 +656,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager, FullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager, SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager, MambaSpec: MambaManager,
......
...@@ -59,13 +59,10 @@ class AttentionSpec(KVCacheSpec): ...@@ -59,13 +59,10 @@ class AttentionSpec(KVCacheSpec):
num_kv_heads: int num_kv_heads: int
head_size: int head_size: int
dtype: torch.dtype dtype: torch.dtype
use_mla: bool
@property @property
def page_size_bytes(self) -> int: def page_size_bytes(self) -> int:
# For MLA we only store a single latent vector return 2 * self.block_size * self.num_kv_heads * self.head_size \
coef = 1 if self.use_mla else 2
return coef * self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype) * get_dtype_size(self.dtype)
...@@ -118,12 +115,13 @@ class FullAttentionSpec(AttentionSpec): ...@@ -118,12 +115,13 @@ class FullAttentionSpec(AttentionSpec):
if spec.sliding_window is not None) if spec.sliding_window is not None)
attention_chunk_size = set(spec.attention_chunk_size for spec in specs attention_chunk_size = set(spec.attention_chunk_size for spec in specs
if spec.attention_chunk_size is not None) if spec.attention_chunk_size is not None)
assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
"MLAAttentionSpec should be merged in MLAAttentionSpec.merge")
merged_spec = cls( merged_spec = cls(
block_size=specs[0].block_size, block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads, num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size, head_size=specs[0].head_size,
dtype=specs[0].dtype, dtype=specs[0].dtype,
use_mla=specs[0].use_mla,
sliding_window=cls.merge_window_sizes(sliding_window), sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
) )
...@@ -140,6 +138,38 @@ class FullAttentionSpec(AttentionSpec): ...@@ -140,6 +138,38 @@ class FullAttentionSpec(AttentionSpec):
return merged_spec return merged_spec
@dataclass(frozen=True)
class MLAAttentionSpec(FullAttentionSpec):
    """Full-attention KV cache spec variant used for MLA layers.

    Extends ``FullAttentionSpec`` with an optional ``cache_dtype_str`` that
    selects a special page layout, and overrides the page-size computation
    and the merge logic accordingly.
    """

    # TODO(Lucas/Chen): less hacky way to do this
    cache_dtype_str: Optional[str] = None

    @property
    def page_size_bytes(self) -> int:
        """Return the size in bytes of one KV cache page for this spec.

        For the ``"fp8_ds_mla"`` cache dtype the page is ``block_size * 656``
        bytes; otherwise it is
        ``block_size * num_kv_heads * head_size * dtype_size``.
        """
        if self.cache_dtype_str == "fp8_ds_mla":
            # See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
            # for details.
            bytes_per_slot = 656
        else:
            bytes_per_slot = (self.num_kv_heads * self.head_size *
                              get_dtype_size(self.dtype))
        return self.block_size * bytes_per_slot

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge the specs of several layers into one ``MLAAttentionSpec``.

        Every entry must be an ``MLAAttentionSpec`` and all entries must
        share the same ``cache_dtype_str``. The remaining fields are taken
        from the first spec in the list.
        """
        only_mla_specs = all(
            isinstance(spec, MLAAttentionSpec) for spec in specs)
        assert only_mla_specs, (
            "All attention layers in the same KV cache group must be "
            "MLAAttentionSpec.")

        distinct_cache_dtypes = {spec.cache_dtype_str for spec in specs}
        assert len(distinct_cache_dtypes) == 1, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method.")

        # Checked after the assertions so that an empty list still fails
        # with the assertion above rather than an index error.
        reference = specs[0]
        return cls(
            block_size=reference.block_size,
            num_kv_heads=reference.num_kv_heads,
            head_size=reference.head_size,
            dtype=reference.dtype,
            cache_dtype_str=distinct_cache_dtypes.pop(),
        )
@dataclass(frozen=True) @dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec): class ChunkedLocalAttentionSpec(AttentionSpec):
attention_chunk_size: int attention_chunk_size: int
...@@ -163,9 +193,6 @@ class ChunkedLocalAttentionSpec(AttentionSpec): ...@@ -163,9 +193,6 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
class SlidingWindowSpec(AttentionSpec): class SlidingWindowSpec(AttentionSpec):
sliding_window: int sliding_window: int
def __post_init__(self):
assert not self.use_mla, "MLA is not supported for sliding window"
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
assert vllm_config.parallel_config.decode_context_parallel_size == 1, \ assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
"DCP not support sliding window." "DCP not support sliding window."
...@@ -266,9 +293,13 @@ class UniformTypeKVCacheSpecs(KVCacheSpec): ...@@ -266,9 +293,13 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
# Different block sizes, not uniform. # Different block sizes, not uniform.
return False return False
one_spec = next(iter(kv_cache_specs.values())) one_spec = next(iter(kv_cache_specs.values()))
if isinstance(one_spec, (FullAttentionSpec, CrossAttentionSpec)): if isinstance(one_spec, FullAttentionSpec):
return all(
isinstance(spec, FullAttentionSpec)
for spec in kv_cache_specs.values())
elif isinstance(one_spec, CrossAttentionSpec):
return all( return all(
isinstance(spec, type(one_spec)) isinstance(spec, CrossAttentionSpec)
for spec in kv_cache_specs.values()) for spec in kv_cache_specs.values())
elif isinstance(one_spec, SlidingWindowSpec): elif isinstance(one_spec, SlidingWindowSpec):
return all( return all(
......
This diff is collapsed.
This diff is collapsed.
...@@ -530,7 +530,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -530,7 +530,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window, sliding_window=attn_module.sliding_window,
use_mla=False,
) )
else: else:
kv_cache_spec[layer_name] = FullAttentionSpec( kv_cache_spec[layer_name] = FullAttentionSpec(
...@@ -538,7 +537,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -538,7 +537,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
use_mla=False,
) )
elif attn_module.attn_type in (AttentionType.ENCODER, elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY): AttentionType.ENCODER_ONLY):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment