Commit 6fc61a0d authored by zhuwenwen
Browse files

Fix FlashAttention (FA) interface and KV-cache handling on ROCm

FlashMLASchedMeta is not supported without a newer flashmla.
parent ae59e10f
......@@ -1414,9 +1414,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = functools.partial(
flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
)
if current_platform.is_rocm():
self.flash_attn_varlen_func = functools.partial(
flash_attn_varlen_func
)
else:
self.flash_attn_varlen_func = functools.partial(
flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
......
......@@ -893,7 +893,10 @@ class FlashAttentionImpl(AttentionImpl):
):
return
key_cache, value_cache = kv_cache.unbind(0)
if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
else:
key_cache, value_cache = kv_cache
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
......@@ -902,16 +905,43 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if not current_platform.is_rocm():
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale
)
else:
from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def _forward_with_dcp(
self,
......
......@@ -27,6 +27,7 @@ from vllm.v1.attention.backend import (
AttentionType,
MultipleOf,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
......@@ -41,7 +42,6 @@ from vllm.v1.attention.ops.flashmla import (
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -320,6 +320,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=scheduler_metadata,
num_splits=scheduler_metadata.num_splits,
softmax_scale=self.scale,
causal=True,
is_fp8_kvcache=False,
......
......@@ -100,7 +100,7 @@ def _raise_flashmla_unavailable(*_args, **_kwargs):
if _is_flashmla_available()[0]:
if current_platform.is_rocm():
from flash_mla.flash_mla_interface import ( # noqa: F401
# FlashMLASchedMeta,
FlashMLASchedMeta, # need new flashmla
# flash_attn_varlen_func,
# flash_attn_varlen_kvpacked_func,
# flash_attn_varlen_qkvpacked_func,
......@@ -122,7 +122,7 @@ else:
class FlashMLASchedMeta: # type: ignore[no-redef]
pass
flash_attn_varlen_func = _raise_flashmla_unavailable # type: ignore[assignment]
flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment]
flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment