Unverified commit 386e3911, authored by jiaxingli and committed by GitHub

Fix: implement deterministic backward in mha (#748)

* fix deterministic

parent 1a2c3e8c
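
This change threads a deterministic flag through the FlashSelfAttention and FlashCrossAttention wrappers into the underlying flash-attn kernel calls. By default the flash-attn backward pass is non-deterministic, because gradient accumulation across thread blocks uses atomic adds, so repeated runs on identical inputs can produce slightly different gradients; passing deterministic=True selects the slower but reproducible backward implementation (added around flash-attn v2.4.1).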
@@ -61,7 +61,7 @@ class FlashSelfAttention(nn.Module):
         (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -69,6 +69,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.deterministic = deterministic

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
@@ -103,6 +104,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
         else:
             return flash_attn_qkvpacked_func(
@@ -111,6 +113,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
@@ -125,7 +128,7 @@ class FlashCrossAttention(nn.Module):
         (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -133,6 +136,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.deterministic = deterministic

     def forward(
         self,
@@ -180,6 +184,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
@@ -192,6 +197,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
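
As a quick illustration (not part of the commit), the sketch below constructs the patched FlashSelfAttention with deterministic=True and checks that two identical backward passes yield bitwise-equal gradients. The import path, shapes, and seed handling are assumptions for the example; the flash-attn kernels require fp16 or bf16 tensors on a CUDA device.

# Minimal usage sketch. Assumption: the patched module is importable from
# flash_attn.modules.mha and flash-attn >= 2.4.1 is installed.
import torch
from flash_attn.modules.mha import FlashSelfAttention  # import path assumed

def run_once(seed):
    # Rebuild the module and identical inputs for each run.
    torch.manual_seed(seed)
    attn = FlashSelfAttention(causal=True, deterministic=True)
    # qkv is packed as (batch, seqlen, 3, nheads, headdim).
    qkv = torch.randn(
        2, 128, 3, 8, 64,
        device="cuda", dtype=torch.float16, requires_grad=True,
    )
    out = attn(qkv)
    out.sum().backward()
    return qkv.grad.detach().clone()

g1 = run_once(0)
g2 = run_once(0)
# With deterministic=True the backward pass is reproducible, so the
# gradients should match bit for bit; with the default (False), atomic-add
# accumulation can make this assertion fail across runs.
assert torch.equal(g1, g2)

Note that this flag only makes the attention backward reproducible; end-to-end run reproducibility also depends on the determinism of the rest of the model and of other PyTorch ops.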