Commit 1acf2d7a authored by zhuwenwen's avatar zhuwenwen
Browse files

update get_mla_decoding_metadata_dense_fp8 interface and _k_scale&_v_scale

parent 77210184
...@@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend, ...@@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8, flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm import envs from vllm import envs
...@@ -89,12 +90,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -89,12 +90,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
batch_size) batch_size)
if m.num_decode_tokens > 0: if m.num_decode_tokens > 0:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \ if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
get_mla_metadata( m.decode_tile_scheduler_metadata, m.decode_num_splits = \
m.seq_lens_tensor[m.num_prefills:], get_mla_decoding_metadata_dense_fp8(
self.num_q_heads, m.seq_lens_tensor[m.num_prefills:],
1, # MQA for the decode path self.num_q_heads,
) 1, # MQA for the decode path
16,
)
else:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
)
return m return m
...@@ -109,13 +119,23 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]): ...@@ -109,13 +119,23 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
@contextmanager @contextmanager
def graph_capture(self, max_batch_size: int): def graph_capture(self, max_batch_size: int):
# Run a dummy `get_mla_metadata` so we can get the right shapes # Run a dummy `get_mla_metadata` so we can get the right shapes
self._graph_decoder_tile_scheduler_metadata, \ if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
self._graph_decode_num_splits = get_mla_metadata( self._graph_decoder_tile_scheduler_metadata, \
torch.ones( self._graph_decode_num_splits = get_mla_decoding_metadata_dense_fp8(
max_batch_size, dtype=torch.int32, device=self.runner.device), torch.ones(
self.num_q_heads, max_batch_size, dtype=torch.int32, device=self.runner.device),
1, # MQA for the decode path self.num_q_heads,
) 1, # MQA for the decode path
16,
)
else:
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_metadata(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
)
with super().graph_capture(max_batch_size): with super().graph_capture(max_batch_size):
yield yield
...@@ -129,11 +149,19 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]): ...@@ -129,11 +149,19 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
batch_size, is_encoder_decoder_model) batch_size, is_encoder_decoder_model)
assert metadata.num_decode_tokens > 0 assert metadata.num_decode_tokens > 0
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
self._graph_seq_lens[:batch_size], decoder_tile_scheduler_metadata, decode_num_splits = get_mla_decoding_metadata_dense_fp8(
self.num_q_heads, self._graph_seq_lens[:batch_size],
1, # MQA for the decode path self.num_q_heads,
) 1, # MQA for the decode path
16,
)
else:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
)
self._graph_decoder_tile_scheduler_metadata.copy_( self._graph_decoder_tile_scheduler_metadata.copy_(
decoder_tile_scheduler_metadata) decoder_tile_scheduler_metadata)
......
...@@ -98,8 +98,12 @@ class Attention(nn.Module): ...@@ -98,8 +98,12 @@ class Attention(nn.Module):
# with the model weights. # with the model weights.
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32) if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
self._v_scale = torch.tensor(1.0, dtype=torch.float32) self._k_scale = torch.ones((1), dtype=torch.float32)
self._v_scale = torch.ones((1), dtype=torch.float32)
else:
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only # FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well. # but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32) self._q_scale = torch.tensor(1.0, dtype=torch.float32)
......
...@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, ...@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe, flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8, flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend, from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
...@@ -72,12 +73,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -72,12 +73,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode(self, block_table_tensor: torch.Tensor, def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \ if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
get_mla_metadata( tile_scheduler_metadata, num_splits = \
seq_lens, get_mla_decoding_metadata_dense_fp8(
self.num_q_heads, seq_lens,
1, # MQA for the decode path self.num_q_heads,
) 1, # MQA for the decode path
16,
)
else:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
if self.runner.full_cuda_graph: if self.runner.full_cuda_graph:
# First time around (CUDAGraph capture), allocate the static buffer # First time around (CUDAGraph capture), allocate the static buffer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment