Unverified Commit 3e102623 authored by Pavani Majety's avatar Pavani Majety Committed by GitHub
Browse files

Revert "[SM100] Enable fp8 compute for prefill MLA (#30746)" (#31197)


Signed-off-by: default avatarPavani Majety <pmajety@nvidia.com>
parent 612d5ffd
...@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv ...@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import MLAAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [ BACKENDS_TO_TEST = [
AttentionBackendEnum.CUTLASS_MLA, AttentionBackendEnum.CUTLASS_MLA,
...@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase): ...@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def run_attention_backend( def run_attention_backend(
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
kv_cache_spec: MLAAttentionSpec, kv_cache_spec: FullAttentionSpec,
layer_names: list[str], layer_names: list[str],
vllm_config, vllm_config,
device: torch.device, device: torch.device,
...@@ -740,7 +740,7 @@ def test_backend_correctness( ...@@ -740,7 +740,7 @@ def test_backend_correctness(
kv_cache = kv_cache_per_block_size[block_size] kv_cache = kv_cache_per_block_size[block_size]
# Create kv_cache_spec with the correct block_size for this backend # Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = MLAAttentionSpec( backend_kv_cache_spec = FullAttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads( num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config vllm_config.parallel_config
...@@ -748,7 +748,6 @@ def test_backend_correctness( ...@@ -748,7 +748,6 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(), head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype, dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(), sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype,
) )
backend_output = run_attention_backend( backend_output = run_attention_backend(
......
...@@ -325,6 +325,7 @@ def flashinfer_trtllm_fp4_moe( ...@@ -325,6 +325,7 @@ def flashinfer_trtllm_fp4_moe(
local_expert_offset=layer.ep_rank * layer.local_num_experts, local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
routed_scaling_factor=None, routed_scaling_factor=None,
tile_tokens_dim=None,
routing_method_type=routing_method_type, routing_method_type=routing_method_type,
do_finalize=True, do_finalize=True,
)[0] )[0]
......
...@@ -541,11 +541,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -541,11 +541,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
metadata_cls if metadata_cls is not None else MLACommonMetadata metadata_cls if metadata_cls is not None else MLACommonMetadata
) )
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.q_data_type = (
current_platform.fp8_dtype()
if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
else vllm_config.model_config.dtype
)
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
...@@ -689,6 +684,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -689,6 +684,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# For main run, qo_indptr == kv_indptr # For main run, qo_indptr == kv_indptr
kv_indptr = qo_indptr.clone() kv_indptr = qo_indptr.clone()
# Prepare main prefill # Prepare main prefill
self._fi_prefill_main.plan( self._fi_prefill_main.plan(
qo_indptr=qo_indptr, qo_indptr=qo_indptr,
...@@ -701,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -701,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale, sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left, window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap, logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.q_data_type, q_data_type=self.model_config.dtype,
) )
# Prepare context prefills # Prepare context prefills
...@@ -720,7 +716,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -720,7 +716,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale, sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left, window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap, logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.q_data_type, q_data_type=self.model_config.dtype,
) )
prefill.prefill_main = self._fi_prefill_main prefill.prefill_main = self._fi_prefill_main
...@@ -973,7 +969,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -973,7 +969,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc, query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len, max_query_len=max_query_len,
chunked_context=chunked_context_metadata, chunked_context=chunked_context_metadata,
q_data_type=self.q_data_type,
) )
if self._use_cudnn_prefill: if self._use_cudnn_prefill:
...@@ -1384,15 +1379,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1384,15 +1379,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out return attn_out
def _run_prefill_new_tokens_fa( def _run_prefill_new_tokens_fa(
self, self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
return self._flash_attn_varlen_diff_headdims( return self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
...@@ -1407,23 +1395,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1407,23 +1395,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
def _run_prefill_new_tokens_fi( def _run_prefill_new_tokens_fi(
self, self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata) assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None assert prefill.prefill_main is not None
if fp8_attention:
logger.debug_once("Running Flashinfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = prefill.prefill_main.run( ret = prefill.prefill_main.run(
q=q, q=q,
k=k, k=k,
...@@ -1436,18 +1412,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1436,18 +1412,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret return ret
def _run_prefill_new_tokens_cudnn( def _run_prefill_new_tokens_cudnn(
self, self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running Cudnn prefill new tokens", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata) assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
output, lse = cudnn_batch_prefill_with_kv_cache( output, lse = cudnn_batch_prefill_with_kv_cache(
q=q, q=q,
k_cache=k, k_cache=k,
...@@ -1469,19 +1437,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1469,19 +1437,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return output return output
def _run_prefill_context_chunk_fa( def _run_prefill_context_chunk_fa(
self, self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
assert prefill.chunked_context is not None assert prefill.chunked_context is not None
assert fp8_attention is False, (
"FlashAttention prefill does not support fp8 attention"
)
return self._flash_attn_varlen_diff_headdims( return self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
...@@ -1496,22 +1454,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1496,22 +1454,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
def _run_prefill_context_chunk_fi( def _run_prefill_context_chunk_fi(
self, self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata) assert isinstance(prefill, FlashInferPrefillMetadata)
if fp8_attention:
logger.debug_once("Running FlashInfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = prefill.prefill_chunks[chunk_idx].run( attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q, q=q,
k=k, k=k,
...@@ -1523,20 +1469,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1523,20 +1469,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out, lse.transpose(0, 1).contiguous() return attn_out, lse.transpose(0, 1).contiguous()
def _run_prefill_context_chunk_cudnn( def _run_prefill_context_chunk_cudnn(
self, self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running Cudnn prefill context chunk", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata) assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.chunked_context is not None assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
return cudnn_batch_prefill_with_kv_cache( return cudnn_batch_prefill_with_kv_cache(
q=q, q=q,
k_cache=k, k_cache=k,
...@@ -1556,28 +1494,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1556,28 +1494,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
def _run_prefill_new_tokens_trtllm_ragged( def _run_prefill_new_tokens_trtllm_ragged(
self, self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
): ):
logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
"""TRT-LLM ragged attention for new tokens (causal).""" """TRT-LLM ragged attention for new tokens (causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None assert prefill.workspace_buffer is not None
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = trtllm_ragged_attention_deepseek( ret = trtllm_ragged_attention_deepseek(
query=q, query=q,
key=k, key=k,
...@@ -1604,15 +1528,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1604,15 +1528,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret return ret
def _run_prefill_context_chunk_trtllm_ragged( def _run_prefill_context_chunk_trtllm_ragged(
self, self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
): ):
logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
"""TRT-LLM ragged attention for context chunks (non-causal).""" """TRT-LLM ragged attention for context chunks (non-causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek from flashinfer.prefill import trtllm_ragged_attention_deepseek
...@@ -1629,13 +1546,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1629,13 +1546,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
prefill.workspace_buffer.fill_(0) prefill.workspace_buffer.fill_(0)
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = trtllm_ragged_attention_deepseek( attn_out, lse = trtllm_ragged_attention_deepseek(
query=q, query=q,
key=k, key=k,
...@@ -1788,7 +1698,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1788,7 +1698,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor, k_scale: torch.Tensor,
fp8_attention: bool,
): ):
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
...@@ -1827,7 +1736,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1827,7 +1736,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q, q=q,
k=k, k=k,
v=v, v=v,
fp8_attention=fp8_attention,
) )
if output is None: if output is None:
...@@ -1856,7 +1764,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1856,7 +1764,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor, k_scale: torch.Tensor,
dcp_world_size: int, dcp_world_size: int,
fp8_attention: bool,
): ):
assert k_scale is None, "DCP not support scaled kvcache now." assert k_scale is None, "DCP not support scaled kvcache now."
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
...@@ -1933,7 +1840,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1933,7 +1840,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q, q=q,
k=k, k=k,
v=v, v=v,
fp8_attention=fp8_attention,
) )
if output is None: if output is None:
...@@ -1964,7 +1870,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1964,7 +1870,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor, k_scale: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
fp8_attention: bool = False,
) -> None: ) -> None:
# TODO (zyongye): Prefill function here # TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
...@@ -1984,7 +1889,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1984,7 +1889,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k=k, k=k,
v=v, v=v,
return_softmax_lse=has_context, return_softmax_lse=has_context,
fp8_attention=fp8_attention,
) )
if has_context: if has_context:
...@@ -1997,12 +1901,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1997,12 +1901,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata, attn_metadata,
k_scale=None, k_scale=None,
dcp_world_size=self.dcp_world_size, dcp_world_size=self.dcp_world_size,
fp8_attention=fp8_attention,
) )
) )
else: else:
context_output, context_lse = self._compute_prefill_context( context_output, context_lse = self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention q, kv_c_and_k_pe_cache, attn_metadata, k_scale
) )
# unpad if necessary # unpad if necessary
...@@ -2123,7 +2026,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2123,7 +2026,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata, attn_metadata,
layer._k_scale, layer._k_scale,
output=output[num_decode_tokens:], output=output[num_decode_tokens:],
fp8_attention=fp8_attention,
) )
if has_decode: if has_decode:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment