Unverified Commit 05181cc5 authored by Asaf Joseph Gardin's avatar Asaf Joseph Gardin Committed by GitHub
Browse files

[Hybrid] Add mamba_block_size to Engine Args (#27289)


Signed-off-by: default avatarasafg <39553475+Josephasafg@users.noreply.github.com>
parent 259504e1
......@@ -5,7 +5,7 @@ import hashlib
from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal
from pydantic import Field, SkipValidation, field_validator
from pydantic import Field, SkipValidation, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
......@@ -90,8 +90,10 @@ class CacheConfig:
mamba_page_size_padded: int | None = None
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""
mamba_block_size: int | None = None
"""Size of a contiguous cache block in number of tokens for mamba cache."""
mamba_block_size: int | None = Field(default=None, gt=0)
"""Size of a contiguous cache block in number of tokens for mamba cache.
Can be set only when prefix caching is enabled.
Value must be a multiple of 8 to align with causal_conv1d kernel."""
mamba_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (both the conv as well as the
ssm state). If set to 'auto', the data type will be inferred from the model
......@@ -183,3 +185,11 @@ class CacheConfig:
raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warning("Possibly too large swap space. %s", msg)
@model_validator(mode="after")
def validate_mamba_block_size(self) -> "CacheConfig":
if self.mamba_block_size is not None and not self.enable_prefix_caching:
raise ValueError(
"--mamba-block-size can only be set with --enable-prefix-caching"
)
return self
......@@ -535,6 +535,7 @@ class EngineArgs:
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
......@@ -893,6 +894,9 @@ class EngineArgs:
cache_group.add_argument(
"--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
)
cache_group.add_argument(
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
)
# Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig)
......@@ -1390,6 +1394,7 @@ class EngineArgs:
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size,
)
ray_runtime_env = None
......
......@@ -291,8 +291,7 @@ class MambaModelConfig(VerifyAndUpdateConfig):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
# Set mamba block size to max_model_len (this may get
# override by prefix caching logic later)
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
if cache_config.enable_prefix_caching:
......@@ -333,6 +332,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
if not envs.VLLM_USE_V1:
return
# Save the user input before it gets modified by MambaModelConfig
mamba_block_size = vllm_config.cache_config.mamba_block_size
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
......@@ -386,7 +387,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
# mamba SSD kernel uses a chunk_size, e.g. 256
# Mamba2 SSD kernel uses a chunk_size, e.g. 256
# Align the block to the kernel: use lowest multiple of chunk_size
# of attention tokens that would fit mamba_page_size:
# e.g. for mamba page size = 788kB
......@@ -404,7 +405,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
def lcm(a, b):
return a * b // gcd(a, b)
base_chunk_size = model_config.get_mamba_chunk_size()
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment