Commit a9f57e73 authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_FLASH_MLA_FP8 to use mla fp8

set VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT=1
parent 8548cf87
...@@ -16,7 +16,10 @@ from vllm.attention.backends.mla.common import (MLACommonBackend, ...@@ -16,7 +16,10 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonState) MLACommonState)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm import envs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
...@@ -87,13 +90,20 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -87,13 +90,20 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
batch_size) batch_size)
if m.num_decode_tokens > 0: if m.num_decode_tokens > 0:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_decoding_metadata_dense_fp8(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
)
else:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \ m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata( get_mla_metadata(
m.seq_lens_tensor[m.num_prefills:], m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads, self.num_q_heads,
1, # MQA for the decode path 1, # MQA for the decode path
) )
return m return m
...@@ -108,6 +118,15 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]): ...@@ -108,6 +118,15 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
@contextmanager @contextmanager
def graph_capture(self, max_batch_size: int): def graph_capture(self, max_batch_size: int):
# Run a dummy `get_mla_metadata` so we can get the right shapes # Run a dummy `get_mla_metadata` so we can get the right shapes
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_decoding_metadata_dense_fp8(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
)
else:
self._graph_decoder_tile_scheduler_metadata, \ self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_metadata( self._graph_decode_num_splits = get_mla_metadata(
torch.ones( torch.ones(
...@@ -128,6 +147,13 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]): ...@@ -128,6 +147,13 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
batch_size, is_encoder_decoder_model) batch_size, is_encoder_decoder_model)
assert metadata.num_decode_tokens > 0 assert metadata.num_decode_tokens > 0
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_decoding_metadata_dense_fp8(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
)
else:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
self._graph_seq_lens[:batch_size], self._graph_seq_lens[:batch_size],
self.num_q_heads, self.num_q_heads,
...@@ -222,6 +248,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -222,6 +248,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata, attn_metadata: FlashMLAMetadata,
q_scale = None,
k_scale = None, k_scale = None,
kv_cache_dtype = "auto", kv_cache_dtype = "auto",
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -233,6 +260,21 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -233,6 +260,21 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
o, _ = flash_mla_with_kvcache_fp8(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
else:
o, _ = flash_mla_with_kvcache( o, _ = flash_mla_with_kvcache(
q=q, q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
...@@ -246,5 +288,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -246,5 +288,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
k_scale = k_scale, k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype, kv_cache_dtype = kv_cache_dtype,
) )
return self._v_up_proj(o) return self._v_up_proj(o)
...@@ -1404,6 +1404,5 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1404,6 +1404,5 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[num_prefill_tokens:] = self._forward_decode( output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._q_scale, layer._k_scale, self.kv_cache_dtype)
return output return output
\ No newline at end of file
...@@ -69,6 +69,27 @@ def get_mla_metadata( ...@@ -69,6 +69,27 @@ def get_mla_metadata(
num_heads_k) num_heads_k)
def get_mla_decoding_metadata_dense_fp8(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Return:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens,
num_heads_per_head_k,
num_heads_k)
def flash_mla_with_kvcache( def flash_mla_with_kvcache(
q: torch.Tensor, q: torch.Tensor,
k_cache: torch.Tensor, k_cache: torch.Tensor,
...@@ -199,6 +220,59 @@ def flash_mla_with_kvcache_q_nope_pe( ...@@ -199,6 +220,59 @@ def flash_mla_with_kvcache_q_nope_pe(
return out, softmax_lse return out, softmax_lse
def flash_mla_with_kvcache_fp8(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head_dim of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
torch.int32, return by get_mla_decoding_metadata_dense_fp8.
num_splits: (batch_size + 1), torch.int32, return by get_mla_decoding_metadata_dense_fp8.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
return out, softmax_lse
# #
# TODO: Add fake functions # TODO: Add fake functions
# #
......
...@@ -146,6 +146,7 @@ if TYPE_CHECKING: ...@@ -146,6 +146,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_MLA: bool = False VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_FLASH_MLA_FP8: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
...@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_MLA": "VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))), lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
# If set, vLLM will use FLASH MLA fp8 attention optimizations.
"VLLM_USE_FLASH_MLA_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA_FP8", "0"))),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP": "VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
......
...@@ -255,8 +255,8 @@ def get_model_architecture( ...@@ -255,8 +255,8 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"): if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1' os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1' os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
...@@ -300,8 +300,8 @@ def get_model_architecture( ...@@ -300,8 +300,8 @@ def get_model_architecture(
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"): if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1' os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1' os.environ['VLLM_SCHED_ENABLE_MINIMAL_INJECTION'] = '1'
if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
......
...@@ -1199,6 +1199,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1199,6 +1199,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[:num_decode_tokens] = self._forward_decode( output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._q_scale, layer._k_scale, self.kv_cache_dtype)
return output_padded return output_padded
\ No newline at end of file
...@@ -11,6 +11,8 @@ from vllm.attention.backends.abstract import (AttentionType, ...@@ -11,6 +11,8 @@ from vllm.attention.backends.abstract import (AttentionType,
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe, flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend, from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
...@@ -71,6 +73,14 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -71,6 +73,14 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode(self, block_table_tensor: torch.Tensor, def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
tile_scheduler_metadata, num_splits = \
get_mla_decoding_metadata_dense_fp8(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
else:
tile_scheduler_metadata, num_splits = \ tile_scheduler_metadata, num_splits = \
get_mla_metadata( get_mla_metadata(
seq_lens, seq_lens,
...@@ -162,12 +172,42 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -162,12 +172,42 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata, attn_metadata: FlashMLAMetadata,
q_scale = None,
k_scale = None, k_scale = None,
kv_cache_dtype = "auto", kv_cache_dtype = "auto",
) -> torch.Tensor: ) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=q_scale,
descale_k=k_scale,
)
else:
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3": if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024: if q_nope.shape[0] < 1024:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment