Unverified Commit 1a456c7c authored by Douglas Lehr's avatar Douglas Lehr Committed by GitHub
Browse files

Aiter mha fp8 fix (#24991)


Signed-off-by: default avatarDoug Lehr <douglehr@amd.com>
Co-authored-by: default avatarDoug Lehr <douglehr@amd.com>
parent fedb75fa
...@@ -81,8 +81,8 @@ class AITERPagedAttention(PagedAttention): ...@@ -81,8 +81,8 @@ class AITERPagedAttention(PagedAttention):
blocksparse_head_sliding_step=blocksparse_head_sliding_step) blocksparse_head_sliding_step=blocksparse_head_sliding_step)
if "fp8" in kv_cache_dtype: if "fp8" in kv_cache_dtype:
key_cache = key_cache.view(torch.float8_e4m3fnuz) key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(torch.float8_e4m3fnuz) value_cache = value_cache.view(current_platform.fp8_dtype())
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention # use blocksparse paged attention
......
...@@ -479,8 +479,8 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -479,8 +479,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
) )
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fnuz) key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(torch.float8_e4m3fnuz) value_cache = value_cache.view(current_platform.fp8_dtype())
if not attn_metadata.use_cascade: if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc cu_seqlens_q = attn_metadata.query_start_loc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment