Commit c77bc77c authored by xiabo
Browse files

1、kvcache支持fp8的scale

parent 3d01cce7
...@@ -700,8 +700,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -700,8 +700,9 @@ class FlashAttentionImpl(AttentionImpl):
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata, suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
# q_descale=layer._q_scale, # q_descale=layer._q_scale,
# k_descale=layer._k_scale, q_descale=None,
# v_descale=layer._v_scale, k_descale=layer._k_scale,
v_descale=layer._v_scale,
) )
return output return output
...@@ -779,6 +780,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -779,6 +780,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale.expand(descale_shape), # q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape), # k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape), # v_descale=layer._v_scale.expand(descale_shape),
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
is_prefix_cache=False, is_prefix_cache=False,
) )
...@@ -932,12 +936,12 @@ def cascade_attention( ...@@ -932,12 +936,12 @@ def cascade_attention(
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata, scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version, # fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None, if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None, if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None, if v_descale is not None else None,
is_prefix_cache=True, is_prefix_cache=True,
) )
...@@ -985,12 +989,12 @@ def cascade_attention( ...@@ -985,12 +989,12 @@ def cascade_attention(
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata, scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version, # fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None, if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None, if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None, if v_descale is not None else None,
is_prefix_cache=True, is_prefix_cache=True,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment