Commit f9b567df authored by zhuwenwen
Browse files

add softcap interface for gemma-2

parent 6dc7aa42
...@@ -350,10 +350,17 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -350,10 +350,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if blocksparse_params is not None: if blocksparse_params is not None:
raise ValueError( raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.") "ROCmFlashAttention does not support blocksparse attention.")
'''
if logits_soft_cap is not None: if logits_soft_cap is not None:
raise ValueError( raise ValueError(
"ROCmFlashAttention does not support attention logits soft " "ROCmFlashAttention does not support attention logits soft "
"capping.") "capping.")
'''
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
...@@ -566,6 +573,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -566,6 +573,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
window_size=self.sliding_window, window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
) )
# common code for prefill # common code for prefill
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment