Commit 1acf2d7a authored by zhuwenwen's avatar zhuwenwen
Browse files

update get_mla_decoding_metadata_dense_fp8 interface and _k_scale&_v_scale

parent 77210184
......@@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported)
from vllm import envs
......@@ -89,12 +90,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
batch_size)
if m.num_decode_tokens > 0:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_decoding_metadata_dense_fp8(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
16,
)
else:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
)
return m
......@@ -109,13 +119,23 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
@contextmanager
def graph_capture(self, max_batch_size: int):
# Run a dummy `get_mla_metadata` so we can get the right shapes
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_metadata(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_decoding_metadata_dense_fp8(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
16,
)
else:
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_metadata(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
)
with super().graph_capture(max_batch_size):
yield
......@@ -129,11 +149,19 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
batch_size, is_encoder_decoder_model)
assert metadata.num_decode_tokens > 0
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_decoding_metadata_dense_fp8(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
16,
)
else:
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
)
self._graph_decoder_tile_scheduler_metadata.copy_(
decoder_tile_scheduler_metadata)
......
......@@ -98,8 +98,12 @@ class Attention(nn.Module):
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
self._k_scale = torch.ones((1), dtype=torch.float32)
self._v_scale = torch.ones((1), dtype=torch.float32)
else:
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
......
......@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata,
flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
......@@ -72,12 +73,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_MLA_FP8:
tile_scheduler_metadata, num_splits = \
get_mla_decoding_metadata_dense_fp8(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
16,
)
else:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
if self.runner.full_cuda_graph:
# First time around (CUDAGraph capture), allocate the static buffer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment