Commit f9b567df authored by zhuwenwen
Browse files

add softcap interface for gemma-2

parent 6dc7aa42
......@@ -350,10 +350,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if blocksparse_params is not None:
raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")
'''
if logits_soft_cap is not None:
raise ValueError(
"ROCmFlashAttention does not support attention logits soft "
"capping.")
'''
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap to 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
......@@ -566,6 +573,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)
# common code for prefill
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment