"vscode:/vscode.git/clone" did not exist on "c3845d82dc3d1831714898114f87d9c103e2dd41"
Commit 1acf2d7a authored by zhuwenwen's avatar zhuwenwen
Browse files

update get_mla_decoding_metadata_dense_fp8 interface and _k_scale&_v_scale

parent 77210184
...@@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend, ...@@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8, flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm import envs from vllm import envs
...@@ -89,6 +90,15 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -89,6 +90,15 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
batch_size) batch_size)
if m.num_decode_tokens > 0: if m.num_decode_tokens > 0:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_decoding_metadata_dense_fp8(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
16,
)
else:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \ m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata( get_mla_metadata(
m.seq_lens_tensor[m.num_prefills:], m.seq_lens_tensor[m.num_prefills:],
...@@ -109,6 +119,16 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]): ...@@ -109,6 +119,16 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
@contextmanager @contextmanager
def graph_capture(self, max_batch_size: int): def graph_capture(self, max_batch_size: int):
# Run a dummy `get_mla_metadata` so we can get the right shapes # Run a dummy `get_mla_metadata` so we can get the right shapes
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_decoding_metadata_dense_fp8(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
16,
)
else:
self._graph_decoder_tile_scheduler_metadata, \ self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_metadata( self._graph_decode_num_splits = get_mla_metadata(
torch.ones( torch.ones(
...@@ -129,6 +149,14 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]): ...@@ -129,6 +149,14 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
batch_size, is_encoder_decoder_model) batch_size, is_encoder_decoder_model)
assert metadata.num_decode_tokens > 0 assert metadata.num_decode_tokens > 0
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_decoding_metadata_dense_fp8(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
16,
)
else:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
self._graph_seq_lens[:batch_size], self._graph_seq_lens[:batch_size],
self.num_q_heads, self.num_q_heads,
......
...@@ -98,6 +98,10 @@ class Attention(nn.Module): ...@@ -98,6 +98,10 @@ class Attention(nn.Module):
# with the model weights. # with the model weights.
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales self.calculate_kv_scales = calculate_kv_scales
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
self._k_scale = torch.ones((1), dtype=torch.float32)
self._v_scale = torch.ones((1), dtype=torch.float32)
else:
self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only # FlashAttn doesn't support quantizing the kv-cache only
......
...@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, ...@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe, flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8, flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend, from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
...@@ -72,6 +73,15 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -72,6 +73,15 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode(self, block_table_tensor: torch.Tensor, def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
tile_scheduler_metadata, num_splits = \
get_mla_decoding_metadata_dense_fp8(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
16,
)
else:
tile_scheduler_metadata, num_splits = \ tile_scheduler_metadata, num_splits = \
get_mla_metadata( get_mla_metadata(
seq_lens, seq_lens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment