@@ -70,7 +70,7 @@ FlashAttention-2 currently supports:
 GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
 GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
-3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
+3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
...
@@ -783,8 +783,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-    if (head_size > 192) {
-        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
+    if (head_size > 192 && (head_size <= 224 || is_dropout)) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
     }
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
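For reference, the relaxed gate above means the backward pass for head dims in (192, 224] still requires A100/A800 or H100/H800, while head dims in (224, 256] need them only when dropout is enabled. A minimal standalone sketch (not library code; `needs_sm80_or_sm90` is a made-up helper that simply mirrors the new condition) tabulating the cases:

```cpp
// Standalone sketch of the relaxed check in mha_bwd/mha_varlen_bwd above.
#include <cstdio>
#include <initializer_list>

// true  -> backward requires sm80/sm90 (A100/A800 or H100/H800)
// false -> backward also runs on other supported GPUs
static bool needs_sm80_or_sm90(int head_size, bool is_dropout) {
    return head_size > 192 && (head_size <= 224 || is_dropout);
}

int main() {
    for (int head_size : {192, 224, 256}) {
        for (bool is_dropout : {false, true}) {
            std::printf("head_size=%3d dropout=%d -> %s\n",
                        head_size, is_dropout,
                        needs_sm80_or_sm90(head_size, is_dropout)
                            ? "requires A100/A800 or H100/H800"
                            : "also runs on other supported GPUs");
        }
    }
    return 0;
}
```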
...
@@ -1020,8 +1020,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-    if (head_size > 192) {
-        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
+    if (head_size > 192 && (head_size <= 224 || is_dropout)) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
     }
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
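`mha_varlen_bwd` gets the identical relaxation. The `is_sm80` / `is_sm90` flags in these checks correspond to compute capabilities 8.0 (A100/A800) and 9.0 (H100/H800). As a hedged sketch using only the standard CUDA runtime API rather than flash-attn's internals, one could probe the current GPU and apply the same gate, e.g. for head dim 256 without dropout:

```cpp
// Hedged sketch: query the current device's compute capability and apply the
// same condition as the checks above. Not part of flash-attn.
#include <cstdio>
#include <cuda_runtime.h>

static bool needs_sm80_or_sm90(int head_size, bool is_dropout) {
    return head_size > 192 && (head_size <= 224 || is_dropout);
}

int main() {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
        std::fprintf(stderr, "no CUDA device found\n");
        return 1;
    }
    // A100/A800 report compute capability 8.0, H100/H800 report 9.0.
    const bool is_sm80 = prop.major == 8 && prop.minor == 0;
    const bool is_sm90 = prop.major == 9 && prop.minor == 0;

    const int head_size = 256;      // example configuration
    const bool is_dropout = false;  // head dim 256 backward without dropout
    const bool ok = !needs_sm80_or_sm90(head_size, is_dropout) || is_sm80 || is_sm90;

    std::printf("%s (sm%d%d): head_size=%d, dropout=%d -> backward %s\n",
                prop.name, prop.major, prop.minor, head_size, is_dropout,
                ok ? "supported" : "requires A100/A800 or H100/H800");
    return 0;
}
```

Compile with nvcc, or with any C++ compiler that has the CUDA headers on its include path and links against `cudart`.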