@@ -874,17 +874,18 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
...
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-    if (head_size > 192 && (head_size <= 224 || is_dropout)) {
-        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
+    if (head_size > 192 && is_dropout) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800");
     }
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
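For readability, the relaxed guard in the hunk above can be restated as a standalone predicate over head_size, dropout, and compute capability. This is an illustrative sketch only: the helper name check_bwd_head_dim is hypothetical, and a plain std::invalid_argument stands in for the TORCH_CHECK macro used in flash_api.cpp.

```cpp
// Illustrative sketch only: the patch keeps this logic inline in mha_bwd /
// mha_varlen_bwd; this hypothetical helper just restates the new condition.
#include <stdexcept>

// head_size:  per-head dimension (multiple of 8, at most 256)
// is_dropout: whether attention dropout is enabled (p_dropout > 0)
// is_sm80 / is_sm90: Ampere (A100/A800) or Hopper (H100/H800) device
inline void check_bwd_head_dim(int head_size, bool is_dropout,
                               bool is_sm80, bool is_sm90) {
    if (head_size % 8 != 0) {
        throw std::invalid_argument("head_size should be a multiple of 8");
    }
    if (head_size > 256) {
        throw std::invalid_argument(
            "FlashAttention backward only supports head dimension at most 256");
    }
    // Old rule (removed lines): head dims in (192, 224] always required
    // sm80/sm90, and head dim up to 256 required it with dropout.
    // New rule (added lines): only head dim > 192 *with dropout* does.
    if (head_size > 192 && is_dropout && !(is_sm80 || is_sm90)) {
        throw std::invalid_argument(
            "FlashAttention backward for head dim > 192 with dropout "
            "requires A100/A800 or H100/H800");
    }
}
```

The practical effect of the change, as read from the removed and added conditions, is that head dims in (192, 224] without dropout no longer require an A100/A800 or H100/H800 for the backward pass; the sm80/sm90 requirement now applies only when head dim > 192 is combined with dropout.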
@@ -1114,13 +1115,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
...
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-    if (head_size > 192 && (head_size <= 224 || is_dropout)) {
-        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
+    if (head_size > 192 && is_dropout) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800");
     }
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
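The unchanged trailing context line also records an existing restriction shared by both paths: softcapping and dropout are mutually exclusive. As another hedged sketch (hypothetical helper name, std::invalid_argument in place of TORCH_CHECK):

```cpp
// Illustrative sketch: mirrors the existing inline check shown above.
#include <stdexcept>

inline void check_softcap_dropout(float softcap, float p_dropout) {
    // Softcapping (softcap > 0) is only supported when dropout is disabled.
    if (softcap > 0.f && p_dropout != 0.f) {
        throw std::invalid_argument(
            "Softcapping does not support dropout for now");
    }
}
```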