Commit b10f41c8 (unverified), authored by Pavani Majety, committed by GitHub
Browse files

[SM100] Enable fp8 compute for prefill MLA (#30746)


Signed-off-by: Pavani Majety <pmajety@nvidia.com>
parent 7b926e89
...@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv ...@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import MLAAttentionSpec
BACKENDS_TO_TEST = [ BACKENDS_TO_TEST = [
AttentionBackendEnum.CUTLASS_MLA, AttentionBackendEnum.CUTLASS_MLA,
...@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase): ...@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def run_attention_backend( def run_attention_backend(
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec, kv_cache_spec: MLAAttentionSpec,
layer_names: list[str], layer_names: list[str],
vllm_config, vllm_config,
device: torch.device, device: torch.device,
...@@ -740,7 +740,7 @@ def test_backend_correctness( ...@@ -740,7 +740,7 @@ def test_backend_correctness(
kv_cache = kv_cache_per_block_size[block_size] kv_cache = kv_cache_per_block_size[block_size]
# Create kv_cache_spec with the correct block_size for this backend # Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = FullAttentionSpec( backend_kv_cache_spec = MLAAttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads( num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config vllm_config.parallel_config
...@@ -748,6 +748,7 @@ def test_backend_correctness( ...@@ -748,6 +748,7 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(), head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype, dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(), sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype,
) )
backend_output = run_attention_backend( backend_output = run_attention_backend(
......
...@@ -325,7 +325,6 @@ def flashinfer_trtllm_fp4_moe( ...@@ -325,7 +325,6 @@ def flashinfer_trtllm_fp4_moe(
local_expert_offset=layer.ep_rank * layer.local_num_experts, local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
routed_scaling_factor=None, routed_scaling_factor=None,
tile_tokens_dim=None,
routing_method_type=routing_method_type, routing_method_type=routing_method_type,
do_finalize=True, do_finalize=True,
)[0] )[0]
......
...@@ -355,6 +355,7 @@ class MLACommonPrefillMetadata: ...@@ -355,6 +355,7 @@ class MLACommonPrefillMetadata:
max_query_len: int max_query_len: int
chunked_context: ChunkedContextMetadata | None = None chunked_context: ChunkedContextMetadata | None = None
query_seq_lens: torch.Tensor | None = None query_seq_lens: torch.Tensor | None = None
q_data_type: torch.dtype | None = None
@dataclass @dataclass
...@@ -539,6 +540,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -539,6 +540,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
metadata_cls if metadata_cls is not None else MLACommonMetadata metadata_cls if metadata_cls is not None else MLACommonMetadata
) )
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.q_data_type = (
current_platform.fp8_dtype()
if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
else vllm_config.model_config.dtype
)
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
...@@ -681,7 +687,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -681,7 +687,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# For main run, qo_indptr == kv_indptr # For main run, qo_indptr == kv_indptr
kv_indptr = qo_indptr.clone() kv_indptr = qo_indptr.clone()
# Prepare main prefill # Prepare main prefill
self._fi_prefill_main.plan( self._fi_prefill_main.plan(
qo_indptr=qo_indptr, qo_indptr=qo_indptr,
...@@ -694,7 +699,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -694,7 +699,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale, sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left, window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap, logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype, q_data_type=self.q_data_type,
) )
# Prepare context prefills # Prepare context prefills
...@@ -713,7 +718,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -713,7 +718,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale, sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left, window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap, logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype, q_data_type=self.q_data_type,
) )
prefill.prefill_main = self._fi_prefill_main prefill.prefill_main = self._fi_prefill_main
...@@ -970,6 +975,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -970,6 +975,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc, query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len, max_query_len=max_query_len,
chunked_context=chunked_context_metadata, chunked_context=chunked_context_metadata,
q_data_type=self.q_data_type,
) )
if self._use_cudnn_prefill: if self._use_cudnn_prefill:
...@@ -1370,8 +1376,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1370,8 +1376,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out return attn_out
def _run_prefill_new_tokens_fa( def _run_prefill_new_tokens_fa(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
return self._flash_attn_varlen_diff_headdims( return self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
...@@ -1386,11 +1399,23 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1386,11 +1399,23 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
def _run_prefill_new_tokens_fi( def _run_prefill_new_tokens_fi(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata) assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None assert prefill.prefill_main is not None
if fp8_attention:
logger.debug_once("Running Flashinfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = prefill.prefill_main.run( ret = prefill.prefill_main.run(
q=q, q=q,
k=k, k=k,
...@@ -1403,10 +1428,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1403,10 +1428,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret return ret
def _run_prefill_new_tokens_cudnn( def _run_prefill_new_tokens_cudnn(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running Cudnn prefill new tokens", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata) assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
output, lse = cudnn_batch_prefill_with_kv_cache( output, lse = cudnn_batch_prefill_with_kv_cache(
q=q, q=q,
k_cache=k, k_cache=k,
...@@ -1428,9 +1461,19 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1428,9 +1461,19 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return output return output
def _run_prefill_context_chunk_fa( def _run_prefill_context_chunk_fa(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
assert prefill.chunked_context is not None assert prefill.chunked_context is not None
assert fp8_attention is False, (
"FlashAttention prefill does not support fp8 attention"
)
return self._flash_attn_varlen_diff_headdims( return self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
...@@ -1445,10 +1488,22 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1445,10 +1488,22 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
def _run_prefill_context_chunk_fi( def _run_prefill_context_chunk_fi(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata) assert isinstance(prefill, FlashInferPrefillMetadata)
if fp8_attention:
logger.debug_once("Running FlashInfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = prefill.prefill_chunks[chunk_idx].run( attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q, q=q,
k=k, k=k,
...@@ -1460,12 +1515,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1460,12 +1515,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out, lse.transpose(0, 1).contiguous() return attn_out, lse.transpose(0, 1).contiguous()
def _run_prefill_context_chunk_cudnn( def _run_prefill_context_chunk_cudnn(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running Cudnn prefill context chunk", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata) assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.chunked_context is not None assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
return cudnn_batch_prefill_with_kv_cache( return cudnn_batch_prefill_with_kv_cache(
q=q, q=q,
k_cache=k, k_cache=k,
...@@ -1485,13 +1548,27 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1485,13 +1548,27 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
def _run_prefill_new_tokens_trtllm_ragged( def _run_prefill_new_tokens_trtllm_ragged(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
"""TRT-LLM ragged attention for new tokens (causal).""" """TRT-LLM ragged attention for new tokens (causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = trtllm_ragged_attention_deepseek( ret = trtllm_ragged_attention_deepseek(
query=q, query=q,
key=k, key=k,
...@@ -1518,8 +1595,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1518,8 +1595,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret return ret
def _run_prefill_context_chunk_trtllm_ragged( def _run_prefill_context_chunk_trtllm_ragged(
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
"""TRT-LLM ragged attention for context chunks (non-causal).""" """TRT-LLM ragged attention for context chunks (non-causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek from flashinfer.prefill import trtllm_ragged_attention_deepseek
...@@ -1535,6 +1619,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1535,6 +1619,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
self._workspace_buffer.fill_(0) self._workspace_buffer.fill_(0)
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = trtllm_ragged_attention_deepseek( attn_out, lse = trtllm_ragged_attention_deepseek(
query=q, query=q,
key=k, key=k,
...@@ -1687,6 +1778,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1687,6 +1778,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor, k_scale: torch.Tensor,
fp8_attention: bool,
): ):
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
...@@ -1725,6 +1817,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1725,6 +1817,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q, q=q,
k=k, k=k,
v=v, v=v,
fp8_attention=fp8_attention,
) )
if output is None: if output is None:
...@@ -1753,6 +1846,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1753,6 +1846,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor, k_scale: torch.Tensor,
dcp_world_size: int, dcp_world_size: int,
fp8_attention: bool,
): ):
assert k_scale is None, "DCP not support scaled kvcache now." assert k_scale is None, "DCP not support scaled kvcache now."
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
...@@ -1829,6 +1923,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1829,6 +1923,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q, q=q,
k=k, k=k,
v=v, v=v,
fp8_attention=fp8_attention,
) )
if output is None: if output is None:
...@@ -1859,6 +1954,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1859,6 +1954,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor, k_scale: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
fp8_attention: bool = False,
) -> None: ) -> None:
# TODO (zyongye): Prefill function here # TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
...@@ -1878,6 +1974,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1878,6 +1974,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k=k, k=k,
v=v, v=v,
return_softmax_lse=has_context, return_softmax_lse=has_context,
fp8_attention=fp8_attention,
) )
if has_context: if has_context:
...@@ -1890,11 +1987,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1890,11 +1987,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata, attn_metadata,
k_scale=None, k_scale=None,
dcp_world_size=self.dcp_world_size, dcp_world_size=self.dcp_world_size,
fp8_attention=fp8_attention,
) )
) )
else: else:
context_output, context_lse = self._compute_prefill_context( context_output, context_lse = self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
) )
# unpad if necessary # unpad if necessary
...@@ -2015,6 +2113,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2015,6 +2113,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata, attn_metadata,
layer._k_scale, layer._k_scale,
output=output[num_decode_tokens:], output=output[num_decode_tokens:],
fp8_attention=fp8_attention,
) )
if has_decode: if has_decode:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment