Commit 6fc61a0d authored by zhuwenwen
Browse files

Fix the flash-attention (FA) interface and KV-cache handling

FlashMLASchedMeta is not supported
parent ae59e10f
...@@ -1414,9 +1414,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1414,9 +1414,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self.flash_attn_varlen_func = flash_attn_varlen_func self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None: if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = functools.partial( if current_platform.is_rocm():
flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version self.flash_attn_varlen_func = functools.partial(
) flash_attn_varlen_func
)
else:
self.flash_attn_varlen_func = functools.partial(
flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
)
# For MLA the v head dim is smaller than qk head dim so we pad out # For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do # v with 0s to match the qk head dim for attention backends that do
......
...@@ -893,7 +893,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -893,7 +893,10 @@ class FlashAttentionImpl(AttentionImpl):
): ):
return return
key_cache, value_cache = kv_cache.unbind(0) if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
else:
key_cache, value_cache = kv_cache
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer. # Skip this if sharing KV cache with an earlier attention layer.
...@@ -902,16 +905,43 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -902,16 +905,43 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash # and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of # op uses the slot_mapping's shape to determine the number of
# actual tokens. # actual tokens.
reshape_and_cache_flash( if not current_platform.is_rocm():
key, reshape_and_cache_flash(
value, key,
key_cache, value,
value_cache, key_cache,
slot_mapping, value_cache,
self.kv_cache_dtype, slot_mapping,
layer._k_scale, self.kv_cache_dtype,
layer._v_scale, layer._k_scale,
) layer._v_scale,
)
else:
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale
)
else:
from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def _forward_with_dcp( def _forward_with_dcp(
self, self,
......
...@@ -27,6 +27,7 @@ from vllm.v1.attention.backend import ( ...@@ -27,6 +27,7 @@ from vllm.v1.attention.backend import (
AttentionType, AttentionType,
MultipleOf, MultipleOf,
) )
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode, reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode, reshape_query_for_spec_decode,
...@@ -41,7 +42,6 @@ from vllm.v1.attention.ops.flashmla import ( ...@@ -41,7 +42,6 @@ from vllm.v1.attention.ops.flashmla import (
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs from vllm import envs
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -320,6 +320,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -320,6 +320,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
cache_seqlens=attn_metadata.decode.seq_lens, cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank, head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=scheduler_metadata, tile_scheduler_metadata=scheduler_metadata,
num_splits=scheduler_metadata.num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
is_fp8_kvcache=False, is_fp8_kvcache=False,
......
...@@ -100,7 +100,7 @@ def _raise_flashmla_unavailable(*_args, **_kwargs): ...@@ -100,7 +100,7 @@ def _raise_flashmla_unavailable(*_args, **_kwargs):
if _is_flashmla_available()[0]: if _is_flashmla_available()[0]:
if current_platform.is_rocm(): if current_platform.is_rocm():
from flash_mla.flash_mla_interface import ( # noqa: F401 from flash_mla.flash_mla_interface import ( # noqa: F401
# FlashMLASchedMeta, FlashMLASchedMeta, # need new flashmla
# flash_attn_varlen_func, # flash_attn_varlen_func,
# flash_attn_varlen_kvpacked_func, # flash_attn_varlen_kvpacked_func,
# flash_attn_varlen_qkvpacked_func, # flash_attn_varlen_qkvpacked_func,
...@@ -122,7 +122,7 @@ else: ...@@ -122,7 +122,7 @@ else:
class FlashMLASchedMeta: # type: ignore[no-redef] class FlashMLASchedMeta: # type: ignore[no-redef]
pass pass
flash_attn_varlen_func = _raise_flashmla_unavailable # type: ignore[assignment] flash_attn_varlen_func = _raise_flashmla_unavailable # type: ignore[assignment]
flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment] flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment]
flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment] flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment