Commit 1cb851b0 authored by zhuwenwen
Browse files

update mla interface

parent 4599e05f
...@@ -29,4 +29,4 @@ triton == 3.3.0 ...@@ -29,4 +29,4 @@ triton == 3.3.0
flash_attn == 2.6.1 flash_attn == 2.6.1
flash_mla == 1.0.0 flash_mla == 1.0.0
lightop == 0.6.0 lightop == 0.6.0
# lmslim == 0.3.1 lmslim == 0.3.1
...@@ -320,7 +320,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -320,7 +320,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
cache_seqlens=attn_metadata.decode.seq_lens, cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank, head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=scheduler_metadata, tile_scheduler_metadata=scheduler_metadata,
num_splits=scheduler_metadata.num_splits, # num_splits=scheduler_metadata.num_splits,
num_splits=None,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
is_fp8_kvcache=False, is_fp8_kvcache=False,
......
...@@ -13,13 +13,6 @@ if current_platform.is_cuda(): ...@@ -13,13 +13,6 @@ if current_platform.is_cuda():
try: try:
import vllm._flashmla_C # noqa: F401 import vllm._flashmla_C # noqa: F401
_flashmla_C_AVAILABLE = True
except ImportError:
_flashmla_C_AVAILABLE = False
elif current_platform.is_rocm():
try:
import flash_mla_cuda # noqa: F401
_flashmla_C_AVAILABLE = True _flashmla_C_AVAILABLE = True
except ImportError: except ImportError:
_flashmla_C_AVAILABLE = False _flashmla_C_AVAILABLE = False
...@@ -30,13 +23,6 @@ if current_platform.is_cuda(): ...@@ -30,13 +23,6 @@ if current_platform.is_cuda():
try: try:
import vllm._flashmla_extension_C # noqa: F401 import vllm._flashmla_extension_C # noqa: F401
_flashmla_extension_C_AVAILABLE = True
except ImportError:
_flashmla_extension_C_AVAILABLE = False
elif current_platform.is_rocm():
try:
import flash_mla_cuda # noqa: F401
_flashmla_extension_C_AVAILABLE = True _flashmla_extension_C_AVAILABLE = True
except ImportError: except ImportError:
_flashmla_extension_C_AVAILABLE = False _flashmla_extension_C_AVAILABLE = False
...@@ -44,6 +30,12 @@ else: ...@@ -44,6 +30,12 @@ else:
_flashmla_extension_C_AVAILABLE = False _flashmla_extension_C_AVAILABLE = False
if current_platform.is_rocm():
import flash_mla.cuda as flash_mla_cuda
_flashmla_C_AVAILABLE = True
_flashmla_extension_C_AVAILABLE = True
def _is_flashmla_available() -> tuple[bool, str | None]: def _is_flashmla_available() -> tuple[bool, str | None]:
if not _flashmla_C_AVAILABLE: if not _flashmla_C_AVAILABLE:
return ( return (
...@@ -100,7 +92,7 @@ def _raise_flashmla_unavailable(*_args, **_kwargs): ...@@ -100,7 +92,7 @@ def _raise_flashmla_unavailable(*_args, **_kwargs):
if _is_flashmla_available()[0]: if _is_flashmla_available()[0]:
if current_platform.is_rocm(): if current_platform.is_rocm():
from flash_mla.flash_mla_interface import ( # noqa: F401 from flash_mla.flash_mla_interface import ( # noqa: F401
FlashMLASchedMeta, # need new flashmla FlashMLASchedMeta,
# flash_attn_varlen_func, # flash_attn_varlen_func,
# flash_attn_varlen_kvpacked_func, # flash_attn_varlen_kvpacked_func,
# flash_attn_varlen_qkvpacked_func, # flash_attn_varlen_qkvpacked_func,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment