Commit b6aa059b authored by Kirthi Shankar Sivamani

Add option for deterministic execution

parent 009a3e71
@@ -50,7 +50,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
 class FlashAttnQKVPackedFunc(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
+                return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -65,6 +66,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen = max_seqlen
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -77,18 +79,19 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
-            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -103,6 +106,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -116,18 +120,19 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
             dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dkv, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -142,6 +147,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -153,18 +159,19 @@ class FlashAttnFunc(torch.autograd.Function):
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dk, dv, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None


 class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         if dropout_p > 0:
             rng_state0 = torch.cuda.get_rng_state()
@@ -196,6 +203,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
         ctx.batch_size0 = batch_size0
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         if not return_softmax:
             return out
         else:
@@ -223,7 +231,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
             cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
-            ctx.softmax_scale, ctx.causal
+            ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
         )
         s = torch.cuda.Stream()
         with torch.cuda.stream(s):
@@ -231,16 +239,17 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
                 dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
                 dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
                 cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
-                ctx.softmax_scale, ctx.causal, generator=generator1
+                ctx.softmax_scale, ctx.causal, generator=generator1,
+                num_splits=1 if ctx.deterministic else 0,
             )
         torch.cuda.current_stream().wait_stream(s)
         if rng_state0 is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None


 def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
-                                       causal=False, return_attn_probs=False):
+                                       causal=False, return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
@@ -254,6 +263,7 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -264,12 +274,12 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
            pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
-                                        causal, return_attn_probs)
+                                        causal, return_attn_probs, deterministic)


 def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                       dropout_p, softmax_scale=None, causal=False,
-                                      return_attn_probs=False):
+                                      return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -287,6 +297,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -298,11 +309,12 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
     """
     return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
                                        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
-                                       return_attn_probs)
+                                       return_attn_probs, deterministic)


 def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False):
+                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
+                             deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -321,6 +333,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -331,12 +344,12 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
            pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                               dropout_p, softmax_scale, causal, return_attn_probs)
+                               dropout_p, softmax_scale, causal, return_attn_probs, deterministic)


 def flash_attn_unpadded_qkvpacked_split_func(
         qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
-        causal=False, return_attn_probs=False):
+        causal=False, return_attn_probs=False, deterministic=False):
     """
     Split attention into 2 kernels running on 2 separate streams for performance reason:
     e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
@@ -358,6 +371,7 @@ def flash_attn_unpadded_qkvpacked_split_func(
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -368,7 +382,8 @@ def flash_attn_unpadded_qkvpacked_split_func(
            pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
-                                             dropout_p, softmax_scale, causal, return_attn_probs)
+                                             dropout_p, softmax_scale, causal, return_attn_probs,
+                                             deterministic)


 def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
...
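
A minimal usage sketch of the new flag (not part of the commit): it assumes the functions above are importable from flash_attn.flash_attn_interface as in the upstream package, and uses arbitrary shapes chosen for illustration. Passing deterministic=True makes the backward pass run with num_splits=1, as wired up in the diff above, which may trade some backward-pass speed for reproducible gradients.

# Illustrative sketch only; shapes, names, and the import path are assumptions.
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

batch_size, seqlen, nheads, headdim = 2, 128, 8, 64
total = batch_size * seqlen

# Packed QKV for all tokens in the (unpadded) batch.
qkv = torch.randn(total, 3, nheads, headdim, dtype=torch.float16,
                  device='cuda', requires_grad=True)
# Cumulative sequence lengths: [0, 128, 256] for two equal-length sequences.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen,
                          dtype=torch.int32, device='cuda')

# deterministic=True selects the single-split backward kernel (num_splits=1),
# per the change in this commit.
out = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, seqlen, dropout_p=0.1,
                                         causal=True, deterministic=True)
out.sum().backward()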