Commit a190df01 authored by Tri Dao's avatar Tri Dao
Browse files

Add window_size option to ParallelMHA

parent 2423cca3
......@@ -747,6 +747,7 @@ class ParallelMHA(nn.Module):
rotary_emb_scale_base=None,
rotary_emb_interleaved=False,
use_alibi=False,
window_size=(-1, -1),
use_flash_attn=False,
checkpointing=False,
sequence_parallel=True,
......@@ -793,6 +794,8 @@ class ParallelMHA(nn.Module):
)
else:
alibi_slopes = None
if window_size != (-1, -1):
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
if self.rotary_emb_dim > 0:
assert RotaryEmbedding is not None, "rotary_emb is not installed"
......@@ -816,12 +819,12 @@ class ParallelMHA(nn.Module):
**factory_kwargs,
)
inner_attn_cls = (
partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
if use_flash_attn
else SelfAttention
)
inner_cross_attn_cls = (
partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
if use_flash_attn
else CrossAttention
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment