Unverified commit d9f83d62, authored by Sage Moore and committed by GitHub
Browse files

[ROCm] Enable chunked prefill/paged attention in MLA on ROCm (#14316)


Signed-off-by: Sage Moore <sage@neuralmagic.com>
parent 4a754fcf
...@@ -1327,21 +1327,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1327,21 +1327,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
[0, q.shape[-1] - v.shape[-1]], [0, q.shape[-1] - v.shape[-1]],
value=0) value=0)
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: if is_vllm_fa:
attn_output, attn_softmax_lse = self.triton_fa_func(
q,
k,
v_padded,
None,
prefill_metadata.query_start_loc,
prefill_metadata.context_chunk_cu_seq_lens[i],
prefill_metadata.max_query_len,
prefill_metadata.context_chunk_max_seq_lens[i],
False, # causal
self.scale,
None, # attn_mask is None unless applying ALiBi mask
)
elif is_vllm_fa:
attn_output, attn_softmax_lse = self.flash_attn_varlen_func( attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
q=q, q=q,
k=k, k=k,
...@@ -1416,7 +1402,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1416,7 +1402,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0) value=0)
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
output = self.triton_fa_func( output = self.triton_fa_func(
q, q,
k, k,
......
...@@ -3450,9 +3450,9 @@ class VllmConfig: ...@@ -3450,9 +3450,9 @@ class VllmConfig:
self.compilation_config.level = CompilationLevel.NO_COMPILATION self.compilation_config.level = CompilationLevel.NO_COMPILATION
if self.model_config and self.model_config.use_mla and \ if self.model_config and self.model_config.use_mla and \
not current_platform.is_cuda(): not (current_platform.is_cuda() or current_platform.is_rocm()):
logger.info( logger.info(
"MLA is enabled on a non-cuda platform; forcing chunked " "MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.") "prefill and prefix caching to be disabled.")
self.scheduler_config.enable_chunked_prefill = False self.scheduler_config.enable_chunked_prefill = False
self.scheduler_config.chunked_prefill_enabled = False self.scheduler_config.chunked_prefill_enabled = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment