Unverified commit e6c47977, authored by Divakar Verma and committed by GitHub
Browse files

[ROCm][Quantization] add fp8xfp8 attn support for rocm_aiter_unified_attn (#36927)


Signed-off-by: Divakar Verma <divakar.verma@amd.com>
parent 09e4576f
...@@ -125,6 +125,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): ...@@ -125,6 +125,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
from aiter.ops.triton.unified_attention import unified_attention from aiter.ops.triton.unified_attention import unified_attention
self.unified_attention = unified_attention self.unified_attention = unified_attention
self.supports_quant_query_input = True
def forward( def forward(
self, self,
...@@ -190,12 +191,20 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): ...@@ -190,12 +191,20 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
softmax_scale = self.scale
fp8_post_attn_v_rescale = False
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, ( # When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul).
"A non 1.0 q_scale is not currently supported." # Compensate by absorbing q_scale and k_scale into softmax_scale, and
) # v_scale into output_scale (or post-multiplying if no fusion).
if query.dtype == self.fp8_dtype:
softmax_scale = self.scale * layer._q_scale_float * layer._k_scale_float
if output_scale is not None:
output_scale = output_scale / layer._v_scale_float
else:
fp8_post_attn_v_rescale = True
cu_seqlens_q = attn_metadata.query_start_loc cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens seqused_k = attn_metadata.seq_lens
...@@ -217,19 +226,22 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): ...@@ -217,19 +226,22 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k, seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k, max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale, softmax_scale=softmax_scale,
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window, window_size=self.sliding_window,
block_table=block_table, block_table=block_table,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
q_descale=None, # Not supported q_descale=None, # q_scale absorbed into softmax_scale
k_descale=layer._k_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks, sinks=self.sinks,
output_scale=output_scale, output_scale=output_scale,
) )
if fp8_post_attn_v_rescale:
output[:num_actual_tokens].mul_(layer._v_scale_float)
return output return output
def do_kv_cache_update( def do_kv_cache_update(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment