TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

const auto sizes = q.sizes();

const int batch_size = sizes[0];
const int seqlen_q = sizes[1];
const int num_heads = sizes[2];
const int head_size_og = dout.size(3);
const int head_size = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
if (head_size > 192 && (head_size <= 224 || is_dropout)) {
    TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
}
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);
const auto sizes = q.sizes();

const int total_q = sizes[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = sizes[1];
const int head_size_og = dout.size(2);
const int head_size = sizes[2];
const int total_k = k.size(0);
const int num_heads_k = k.size(1);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
if (head_size > 192 && (head_size <= 224 || is_dropout)) {
    TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
}
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
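// Illustrative sketch (not part of the excerpt above): the checks require that
// num_heads_k divides num_heads (MQA/GQA) and that head_size is a multiple of 8,
// at most 256. The constexpr helpers below are hypothetical and only mirror that
// arithmetic; they are not functions from the FlashAttention sources.

constexpr bool heads_compatible(int num_heads, int num_heads_k) {
    // "Number of heads in key/value must divide number of heads in query"
    return num_heads_k > 0 && num_heads % num_heads_k == 0;
}

constexpr int pad_head_size(int head_size_og) {
    // Round an original head dimension up to the next multiple of 8.
    return (head_size_og + 7) / 8 * 8;
}

static_assert(heads_compatible(32, 8), "GQA: 4 query heads per KV head");
static_assert(!heads_compatible(32, 12), "12 does not divide 32");
static_assert(pad_head_size(100) == 104, "100 is padded up to 104");
static_assert(pad_head_size(104) == 104, "already a multiple of 8");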
mha_fwd_kvcache(at::Tensor &q,            // batch_size x seqlen_q x num_heads x head_size
                const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
...
...
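// Illustrative sketch (hypothetical, not from this excerpt): when a block_table is
// used, the KV cache is laid out as num_blocks x page_block_size x num_heads_k x head_size,
// and block_table presumably maps each logical page of a sequence to a physical block.
// The helper below only demonstrates that index arithmetic under those assumptions;
// its name and signature are not from the FlashAttention sources.
#include <cstdint>
#include <vector>

// Map (batch b, token position pos) to a row offset into the paged KV cache,
// measured in rows of (num_heads_k x head_size) elements.
inline std::int64_t paged_row_offset(const std::vector<std::vector<int>> &block_table,
                                     int b, int pos, int page_block_size) {
    const int logical_page = pos / page_block_size;      // which page of sequence b
    const int offset_in_page = pos % page_block_size;    // position within that page
    const int physical_block = block_table[b][logical_page];
    return static_cast<std::int64_t>(physical_block) * page_block_size + offset_in_page;
}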
@@ -1493,7 +932,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {