@@ -413,8 +413,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     const int total_k = k.size(TOTAL_DIM);
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
-    if (head_size > 64) {  // TODO: eventually we should support SM86 and SM70 with d=128 as well
-        TORCH_CHECK(is_sm80 || is_sm90);
+    if (head_size > 64) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 64 requires A100 or H100 GPUs as the implementation needs a large amount of shared memory.");
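
For reference, a minimal sketch of how the `is_sm80` / `is_sm90` flags checked above are typically derived in a PyTorch CUDA extension, via the current device's compute capability. The helper name `check_bwd_arch_support` is hypothetical; the actual derivation elsewhere in this file may differ.

// Hypothetical sketch, not part of this hunk: derive the SM80/SM90 flags
// from the current device's compute capability and apply the same check.
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

static void check_bwd_arch_support(const int head_size) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    const bool is_sm80 = dprops->major == 8 && dprops->minor == 0;  // A100
    const bool is_sm90 = dprops->major == 9 && dprops->minor == 0;  // H100
    if (head_size > 64) {
        TORCH_CHECK(is_sm80 || is_sm90,
                    "FlashAttention backward for head dim > 64 requires A100 or H100 GPUs "
                    "as the implementation needs a large amount of shared memory.");
    }
}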