Unverified commit e1a7fe4a, authored by Mickaël Seznec, committed by GitHub
Browse files

[BugFix] fix: aot passes kvcache dtype information (#19750)


Signed-off-by: Mickael Seznec <mickael@mistral.ai>
parent 82de9b9d
......@@ -99,6 +99,13 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
@staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass
class FlashAttentionMetadata:
......@@ -161,6 +168,7 @@ class FlashAttentionMetadataBuilder(
self.parallel_config)
self.num_heads_kv = self.model_config.get_num_kv_heads(
self.parallel_config)
self.kv_cache_dtype = kv_cache_spec.dtype
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
......@@ -239,17 +247,24 @@ class FlashAttentionMetadataBuilder(
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
cache_dtype)
else:
qkv_dtype = self.kv_cache_dtype
if aot_schedule:
return get_scheduler_metadata(
batch_size=batch_size,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
cache_seqlens=seqlens,
num_heads_q=self.num_heads_q,
num_heads_kv=self.num_heads_kv,
headdim=self.headdim,
page_size=self.block_size,
cache_seqlens=seqlens,
qkv_dtype=qkv_dtype,
cu_seqlens_q=cu_query_lens,
page_size=self.block_size,
causal=causal,
window_size=self.aot_sliding_window,
num_splits=self.max_num_splits,
......@@ -474,8 +489,10 @@ class FlashAttentionImpl(AttentionImpl):
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype)
key_cache = key_cache.view(dtype)
value_cache = value_cache.view(dtype)
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment