Commit 9dbc491a authored by Tri Dao

Rename, add benchmarking script

parent 1fcbe6f0
Our implementation uses Apex's
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
as a starting point.
We thank [Young-jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation
and for his thoughtful answers to our questions about CUDA.
@@ -5,11 +5,11 @@ import torch.nn as nn
 from einops import rearrange
 from rotary import RotaryEmbedding, RotaryEmbedding2D
-from stream_attn_interface import stream_attn_func
+from flash_attn_interface import flash_attn_func
 from bert_padding import unpad_input, pad_input, index_first_axis
-class StreamingAttention(nn.Module):
+class FlashAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
     Arguments
     ---------
@@ -49,7 +49,7 @@ class StreamingAttention(nn.Module):
                 max_s = seqlen
                 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                           device=qkv.device)
-                output = stream_attn_func(qkv, cu_seqlens, self.dropout_p if self.training else 0.0,
+                output = flash_attn_func(qkv, cu_seqlens, self.dropout_p if self.training else 0.0,
                                           max_s, softmax_scale=self.softmax_temp, causal=causal)
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
             else:
@@ -58,7 +58,7 @@ class StreamingAttention(nn.Module):
                 x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                 x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                 x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
-                output_unpad = stream_attn_func(x_unpad, cu_seqlens,
+                output_unpad = flash_attn_func(x_unpad, cu_seqlens,
                                                 self.dropout_p if self.training else 0.0,
                                                 max_s, softmax_scale=self.softmax_temp, causal=causal)
                 output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
@@ -66,14 +66,14 @@ class StreamingAttention(nn.Module):
                                    'b s (h d) -> b s h d', h=nheads)
         else:
             assert max_s is not None
-            output = stream_attn_func(qkv, cu_seqlens,
+            output = flash_attn_func(qkv, cu_seqlens,
                                       self.dropout_p if self.training else 0.0,
                                       max_s, softmax_scale=self.softmax_temp, causal=causal)
         return output, None
-class StreamingMHA(nn.Module):
+class FlashMHA(nn.Module):
     def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                  causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
@@ -96,7 +96,7 @@ class StreamingMHA(nn.Module):
             self.rotary_emb = RotaryEmbedding2D(self.head_dim)
         self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
-        self.inner_attn = StreamingAttention(attention_dropout=attention_dropout, **factory_kwargs)
+        self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
     def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
...
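For reference, a minimal construction and call sketch for the renamed `FlashMHA` wrapper, based only on the signatures visible in the hunks above. The import path, dtype, and device here are assumptions, not part of this commit:

```python
# Sketch only: assumes the flash_attn_cuda extension is built and a CUDA GPU is available.
import torch
from flash_attention import FlashMHA  # hypothetical import path; use wherever FlashMHA is defined

device, dtype = 'cuda', torch.float16  # fp16 assumed, since the fused kernel targets half precision

mha = FlashMHA(embed_dim=512, num_heads=8, attention_dropout=0.1, causal=True,
               device=device, dtype=dtype)
x = torch.randn(2, 1024, 512, device=device, dtype=dtype)

# forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, ...)
# ignores its second and third positional arguments, so the same tensor is passed three times.
out = mha(x, x, x)
```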
@@ -2,11 +2,11 @@
 import torch
 import torch.nn as nn
-import stream_attn_cuda
+import flash_attn_cuda
-def _stream_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax):
+def _flash_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax):
-    context, softmax_lse, *rest = stream_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s, softmax_scale,
+    context, softmax_lse, *rest = flash_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s, softmax_scale,
                                                        False, causal, return_softmax, None)
     # if context.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
@@ -14,16 +14,16 @@ def _stream_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causa
     return context, softmax_lse, S_dmask
-def _stream_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s,
+def _flash_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s,
                           softmax_scale, causal):
-    dqkv, dp, softmax_d = stream_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p,
+    dqkv, dp, softmax_d = flash_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p,
                                                softmax_scale, max_s, False, causal, None)
     # if dqkv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
     return dqkv
-class StreamAttnFun(torch.autograd.Function):
+class FlashAttnFun(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
@@ -31,7 +31,7 @@ class StreamAttnFun(torch.autograd.Function):
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        context, softmax_lse, S_dmask = _stream_attn_forward(
+        context, softmax_lse, S_dmask = _flash_attn_forward(
             qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=False
         )
         ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
@@ -48,7 +48,7 @@ class StreamAttnFun(torch.autograd.Function):
             cur_rng_state = torch.cuda.get_rng_state()
             torch.cuda.set_rng_state(rng_state)
         # S_dmask is None, temporarily use another tensor just to get it running
-        dqkv = _stream_attn_backward(
+        dqkv = _flash_attn_backward(
             dout, qkv, context, context, softmax_lse, cu_seqlens, ctx.dropout_p,
             ctx.max_s, ctx.softmax_scale, ctx.causal
         )
@@ -59,7 +59,7 @@ class StreamAttnFun(torch.autograd.Function):
 # We duplicate code to return both the output and the softmax for testing
 # Returning both makes backward a bit slower, so we want to keep using the other version for speed.
-class StreamAttnFunWithS(torch.autograd.Function):
+class FlashAttnFunWithS(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
@@ -67,7 +67,7 @@ class StreamAttnFunWithS(torch.autograd.Function):
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        context, softmax_lse, S_dmask = _stream_attn_forward(
+        context, softmax_lse, S_dmask = _flash_attn_forward(
             qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=True
         )
         ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
@@ -83,7 +83,7 @@ class StreamAttnFunWithS(torch.autograd.Function):
         if rng_state is not None:
             cur_rng_state = torch.cuda.get_rng_state()
             torch.cuda.set_rng_state(rng_state)
-        dqkv = _stream_attn_backward(
+        dqkv = _flash_attn_backward(
             dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, ctx.dropout_p,
             ctx.max_s, ctx.softmax_scale, ctx.causal
         )
@@ -92,9 +92,9 @@ class StreamAttnFunWithS(torch.autograd.Function):
         return dqkv, None, None, None, None, None
-def stream_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
+def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
                     return_attn_probs=False):
     """dropout_p should be set to 0.0 during evaluation
     """
-    func = StreamAttnFun if not return_attn_probs else StreamAttnFunWithS
+    func = FlashAttnFun if not return_attn_probs else FlashAttnFunWithS
     return func.apply(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal)
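Taken together, the renamed public entry point is `flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False, return_attn_probs=False)`. A minimal usage sketch follows, mirroring the packed-qkv and `cu_seqlens` pattern the `FlashAttention` module uses above; the fp16 dtype and CUDA device are assumed requirements, and the compiled extension must be available:

```python
# Sketch only: assumes the flash_attn_cuda extension is installed and a CUDA GPU is available.
import torch
from einops import rearrange
from flash_attn_interface import flash_attn_func

batch_size, seqlen, nheads, head_dim = 2, 1024, 8, 64
# Packed qkv of shape (batch, seqlen, 3, nheads, head_dim); fp16 assumed.
qkv = torch.randn(batch_size, seqlen, 3, nheads, head_dim,
                  dtype=torch.float16, device='cuda', requires_grad=True)

# Flatten batch and sequence dims and describe sequence boundaries with cu_seqlens,
# exactly as the FlashAttention module does in the no-padding case.
qkv_packed = rearrange(qkv, 'b s ... -> (b s) ...')
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device=qkv.device)

out = flash_attn_func(qkv_packed, cu_seqlens, 0.0, seqlen,
                      softmax_scale=None, causal=False)
# Output has shape (batch * seqlen, nheads, head_dim); reshape back per batch.
out = rearrange(out, '(b s) ... -> b s ...', b=batch_size)
```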
@@ -6,11 +6,11 @@ from einops import rearrange
 import hydra
-from stream_blocksparse_attn_interface import stream_blocksparse_attn_func
-from stream_blocksparse_attn_interface import convert_blockmask
+from flash_blocksparse_attn_interface import flash_blocksparse_attn_func
+from flash_blocksparse_attn_interface import convert_blockmask
 from bert_padding import unpad_input, pad_input, index_first_axis
-class StreamingBlocksparseAttention(nn.Module):
+class FlashBlocksparseAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
     Arguments
     ---------
@@ -63,7 +63,7 @@ class StreamingBlocksparseAttention(nn.Module):
                 max_s = seqlen
                 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                           device=qkv.device)
-                output = stream_blocksparse_attn_func(
+                output = flash_blocksparse_attn_func(
                     qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
                     max_s, softmax_scale=self.softmax_temp, causal=causal
                 )
@@ -74,7 +74,7 @@ class StreamingBlocksparseAttention(nn.Module):
                 x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                 x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                 x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
-                output_unpad = stream_blocksparse_attn_func(
+                output_unpad = flash_blocksparse_attn_func(
                     x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
                     max_s, softmax_scale=self.softmax_temp, causal=causal
                 )
@@ -89,12 +89,12 @@ class StreamingBlocksparseAttention(nn.Module):
             assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
             blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
             if convert_mask:
-                output = stream_blocksparse_attn_func(
+                output = flash_blocksparse_attn_func(
                     qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
                     max_s, softmax_scale=self.softmax_temp, causal=causal
                 )
             else:
-                output = stream_blocksparse_attn_func(
+                output = flash_blocksparse_attn_func(
                     qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0,
                     max_s, softmax_scale=self.softmax_temp, causal=causal,
                     convert_mask=False,
@@ -103,7 +103,7 @@ class StreamingBlocksparseAttention(nn.Module):
         return output, None
-class StreamingBlocksparseMHA(nn.Module):
+class FlashBlocksparseMHA(nn.Module):
     def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
                  attention_dropout=0.0, causal=False, max_seq_length=2048,
@@ -120,7 +120,7 @@ class StreamingBlocksparseMHA(nn.Module):
         assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
         self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
-        self.inner_attn = StreamingBlocksparseAttention(
+        self.inner_attn = FlashBlocksparseAttention(
             sparsity_config, attention_dropout=attention_dropout,
             max_seq_length=max_seq_length, **factory_kwargs
         )
...
@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
-import stream_attn_cuda
+import flash_attn_cuda
 def convert_blockmask(blockmask, causal):
@@ -40,9 +40,9 @@ def convert_blockmask(blockmask, causal):
     return nonzero_idx.T.contiguous().to(dtype=torch.int32)
-def _stream_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
+def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
                                      causal, return_softmax):
-    context, softmax_lse, *rest = stream_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
+    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
                                                              max_s, softmax_scale, causal,
                                                              return_softmax, None)
     # if context.isnan().any() or softmax_lse.isnan().any():
@@ -51,9 +51,9 @@ def _stream_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_
     return context, softmax_lse, S_dmask
-def _stream_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
+def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
                                       dropout_p, max_s, softmax_scale, causal):
-    dqkv, dp, softmax_d = stream_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
+    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
                                                      blockmask, dropout_p, softmax_scale, max_s,
                                                      causal, None)
     # if dqkv.isnan().any() or softmax_d.isnan().any():
@@ -61,7 +61,7 @@ def _stream_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_s
     return dqkv
-class StreamBlocksparseAttnFun(torch.autograd.Function):
+class FlashBlocksparseAttnFun(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
@@ -69,7 +69,7 @@ class StreamBlocksparseAttnFun(torch.autograd.Function):
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        context, softmax_lse, S_dmask = _stream_blocksparse_attn_forward(
+        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
             qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
             return_softmax=False
         )
@@ -87,7 +87,7 @@ class StreamBlocksparseAttnFun(torch.autograd.Function):
             cur_rng_state = torch.cuda.get_rng_state()
             torch.cuda.set_rng_state(rng_state)
         # S_dmask is None, temporarily use another tensor just to get it running
-        dqkv = _stream_blocksparse_attn_backward(
+        dqkv = _flash_blocksparse_attn_backward(
             dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
             ctx.max_s, ctx.softmax_scale, ctx.causal
         )
@@ -98,7 +98,7 @@ class StreamBlocksparseAttnFun(torch.autograd.Function):
 # We duplicate code to return both the output and the softmax for testing
 # Returning both makes backward a bit slower, so we want to keep using the other version for speed.
-class StreamBlocksparseAttnFunWithS(torch.autograd.Function):
+class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
@@ -106,7 +106,7 @@ class StreamBlocksparseAttnFunWithS(torch.autograd.Function):
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        context, softmax_lse, S_dmask = _stream_blocksparse_attn_forward(
+        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
             qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
             return_softmax=True
         )
@@ -123,7 +123,7 @@ class StreamBlocksparseAttnFunWithS(torch.autograd.Function):
         if rng_state is not None:
             cur_rng_state = torch.cuda.get_rng_state()
             torch.cuda.set_rng_state(rng_state)
-        dqkv = _stream_blocksparse_attn_backward(
+        dqkv = _flash_blocksparse_attn_backward(
             dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
             ctx.max_s, ctx.softmax_scale, ctx.causal
         )
@@ -132,11 +132,11 @@ class StreamBlocksparseAttnFunWithS(torch.autograd.Function):
         return dqkv, None, None, None, None, None, None
-def stream_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
+def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
                                 causal=False, return_attn_probs=False, convert_mask=True):
     """dropout_p should be set to 0.0 during evaluation
     """
-    func = StreamBlocksparseAttnFun if not return_attn_probs else StreamBlocksparseAttnFunWithS
+    func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
     if convert_mask:
         blockmask = convert_blockmask(blockmask, causal=causal)
     return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)
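A corresponding sketch for the block-sparse entry point, `flash_blocksparse_attn_func`. The 16×256 block granularity of the mask is inferred from the `layout[:seqlen_rounded // 16, :seqlen_rounded // 256]` slicing in the module above, and the 0/1 mask convention (nonzero entries mark kept blocks, consistent with `convert_blockmask` taking nonzero indices) is an assumption:

```python
# Sketch only: assumes the flash_attn_cuda extension with the block-sparse kernels is available.
import torch
from einops import rearrange
from flash_blocksparse_attn_interface import flash_blocksparse_attn_func

batch_size, seqlen, nheads, head_dim = 2, 1024, 8, 64  # head_dim must be 16, 32, or 64
qkv = torch.randn(batch_size, seqlen, 3, nheads, head_dim,
                  dtype=torch.float16, device='cuda', requires_grad=True)

qkv_packed = rearrange(qkv, 'b s ... -> (b s) ...')
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device=qkv.device)

# Block layout indexed as [seqlen // 16, seqlen // 256]; all-ones keeps every block
# (dense attention), zeros would drop the corresponding blocks. Convention assumed.
blockmask = torch.ones(seqlen // 16, seqlen // 256, dtype=torch.uint8, device='cuda')

# convert_mask=True converts the layout via convert_blockmask before the kernel call.
out = flash_blocksparse_attn_func(qkv_packed, cu_seqlens, blockmask, 0.0, seqlen,
                                  softmax_scale=None, causal=False, convert_mask=True)
out = rearrange(out, '(b s) ... -> b s ...', b=batch_size)
```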