Commit ef0ed106 authored by Tri Dao

Add window_size option to MHA and GPT

parent dc72d960
@@ -78,6 +78,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
     rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
     use_alibi = getattr(config, "use_alibi", False)
+    window_size = getattr(config, "window_size", (-1, -1))
     use_flash_attn = getattr(config, "use_flash_attn", False)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     if not fused_bias_fc:
@@ -110,6 +111,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
         rotary_emb_scale_base=rotary_emb_scale_base,
         rotary_emb_interleaved=rotary_emb_interleaved,
         use_alibi=use_alibi,
+        window_size=window_size,
         use_flash_attn=use_flash_attn,
         **serial_kwargs,
         **parallel_kwargs,
...
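The hunk above makes create_mixer_cls read the new option with getattr(config, "window_size", (-1, -1)) and forward it to the attention mixer. A minimal, hedged sketch of enabling a sliding window through the config (GPTLMHeadModel and the extra config attributes follow the flash-attn GPT conventions; the model sizes and the (127, 0) window are illustrative, not taken from this commit):

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

# GPT-2-small-like shape; values chosen only for illustration.
config = GPT2Config(n_embd=768, n_head=12, n_layer=12, vocab_size=50257)

# Extra attributes are picked up via getattr() in create_mixer_cls.
config.use_flash_attn = True   # required whenever window_size != (-1, -1)
config.window_size = (127, 0)  # causal sliding window: 127 tokens left, 0 right

model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, config.vocab_size, (1, 512), device="cuda")
logits = model(input_ids).logits  # (1, 512, vocab_size)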
@@ -61,7 +61,15 @@ class FlashSelfAttention(nn.Module):
             (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -69,6 +77,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -104,6 +113,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -113,6 +123,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
@@ -128,7 +139,15 @@ class FlashCrossAttention(nn.Module):
             (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        alibi_slopes=None,
+        window_size=(-1, -1),
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -136,6 +155,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic

     def forward(
@@ -184,6 +204,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -197,6 +218,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
@@ -372,6 +394,7 @@ class MHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         fused_bias_fc=False,
         use_flash_attn=False,
         return_residual=False,
@@ -401,6 +424,8 @@ class MHA(nn.Module):
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -431,12 +456,12 @@ class MHA(nn.Module):
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
...
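For the MHA module itself, window_size is passed into FlashSelfAttention/FlashCrossAttention and from there to the flash_attn kernels, which interpret it as a (left, right) token window, with (-1, -1) meaning unrestricted attention. A small usage sketch of the new keyword (the embed_dim/num_heads values and the (127, 0) window are illustrative assumptions):

import torch
from flash_attn.modules.mha import MHA

# Causal self-attention limited to a 128-token sliding window
# (127 tokens to the left of each query, 0 to the right).
mha = MHA(
    embed_dim=1024,
    num_heads=16,
    causal=True,
    use_flash_attn=True,   # asserted above when window_size != (-1, -1)
    window_size=(127, 0),
    device="cuda",
    dtype=torch.float16,
)

x = torch.randn(2, 2048, 1024, device="cuda", dtype=torch.float16)
out = mha(x)  # same shape as x: (2, 2048, 1024)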