Commit a190df01 authored by Tri Dao's avatar Tri Dao
Browse files

Add window_size option to ParallelMHA

parent 2423cca3
@@ -747,6 +747,7 @@ class ParallelMHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         use_flash_attn=False,
         checkpointing=False,
         sequence_parallel=True,
@@ -793,6 +794,8 @@ class ParallelMHA(nn.Module):
             )
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
@@ -816,12 +819,12 @@ class ParallelMHA(nn.Module):
             **factory_kwargs,
         )
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
             else CrossAttention
         )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment