mha_fwd(at::Tensor&q,// batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
constat::Tensor&k,// batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
constat::Tensor&v,// batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
std::optional<at::Tensor>&out_,// batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
std::optional<at::Tensor>&alibi_slopes_,// num_heads or batch_size x num_heads
constfloatp_dropout,
constfloatsoftmax_scale,
boolis_causal,
intwindow_size_left,
intwindow_size_right,
constfloatsoftcap,
constboolreturn_softmax,
std::optional<at::Generator>gen_);
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor&q,// total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
constat::Tensor&k,// total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
constat::Tensor&v,// total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
std::optional<at::Tensor>&out_,// total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
constat::Tensor&cu_seqlens_q,// b+1
constat::Tensor&cu_seqlens_k,// b+1
std::optional<at::Tensor>&seqused_k,// b. If given, only this many elements of each batch element's keys are used.
std::optional<at::Tensor>&block_table_,// batch_size x max_num_blocks_per_seq
std::optional<at::Tensor>&alibi_slopes_,// num_heads or b x num_heads
intmax_seqlen_q,
constintmax_seqlen_k,
constfloatp_dropout,
constfloatsoftmax_scale,
constboolzero_tensors,
boolis_causal,
intwindow_size_left,
intwindow_size_right,
constfloatsoftcap,
constboolreturn_softmax,
std::optional<at::Generator>gen_);
std::vector<at::Tensor>
mha_bwd(constat::Tensor&dout,// batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
constat::Tensor&q,// batch_size x seqlen_q x num_heads x head_size
constat::Tensor&k,// batch_size x seqlen_k x num_heads_k x head_size
constat::Tensor&v,// batch_size x seqlen_k x num_heads_k x head_size
constat::Tensor&out,// batch_size x seqlen_q x num_heads x head_size
constat::Tensor&softmax_lse,// b x h x seqlen_q
std::optional<at::Tensor>&dq_,// batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor>&dk_,// batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor>&dv_,// batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor>&alibi_slopes_,// num_heads or batch_size x num_heads
constfloatp_dropout,// probability to drop
constfloatsoftmax_scale,
constboolis_causal,
intwindow_size_left,
intwindow_size_right,
constfloatsoftcap,
constbooldeterministic,
std::optional<at::Generator>gen_,
std::optional<at::Tensor>&rng_state);
std::vector<at::Tensor>
mha_varlen_bwd(constat::Tensor&dout,// total_q x num_heads, x head_size
constat::Tensor&q,// total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
constat::Tensor&k,// total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
constat::Tensor&v,// total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
constat::Tensor&out,// total_q x num_heads x head_size
constat::Tensor&softmax_lse,// h x total_q, softmax logsumexp
std::optional<at::Tensor>&dq_,// total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
std::optional<at::Tensor>&dk_,// total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
std::optional<at::Tensor>&dv_,// total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
constat::Tensor&cu_seqlens_q,// b+1
constat::Tensor&cu_seqlens_k,// b+1
std::optional<at::Tensor>&alibi_slopes_,// num_heads or b x num_heads
constintmax_seqlen_q,
constintmax_seqlen_k,// max sequence length to choose the kernel
constfloatp_dropout,// probability to drop
constfloatsoftmax_scale,
constboolzero_tensors,
constboolis_causal,
intwindow_size_left,
intwindow_size_right,
constfloatsoftcap,
constbooldeterministic,
std::optional<at::Generator>gen_,
std::optional<at::Tensor>&rng_state);
std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor&q,// batch_size x seqlen_q x num_heads x head_size
constat::Tensor&kcache,// batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
constat::Tensor&vcache,// batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
std::optional<constat::Tensor>&k_,// batch_size x seqlen_knew x num_heads_k x head_size
std::optional<constat::Tensor>&v_,// batch_size x seqlen_knew x num_heads_k x head_size