mha_fwd(Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
        Tensor &k,  // batch_size x seqlen_k x num_heads_k x head_size
        Tensor &v,  // batch_size x seqlen_k x num_heads_k x head_size
        // c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
        // c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const bool return_softmax
        // c10::optional<at::Generator> gen_
        );
std::vector<Tensor>
mha_varlen_fwd(Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               // std::optional<Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
               Tensor &cu_seqlens_q,  // b+1
               Tensor &cu_seqlens_k,  // b+1
               // std::optional<Tensor> &seqused_k,     // b. If given, only this many elements of each batch element's keys are used.
               // std::optional<Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
               // std::optional<Tensor> &alibi_slopes_, // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const bool return_softmax);
std::vector<Tensor>
mha_fwd_block(const Tensor &q,  // total_q x num_heads x head_size, total := \sum_{i=0}^{b} s_i
              const Tensor &k,  // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
              const Tensor &v,  // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,        // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,  // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,  // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const bool return_softmax,
        c10::optional<at::Generator> gen_);
std::vector<Tensor>
mha_fwd(Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
        Tensor &k,  // batch_size x seqlen_k x num_heads_k x head_size
        Tensor &v,  // batch_size x seqlen_k x num_heads_k x head_size
        // c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
        // c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const bool return_softmax
        // c10::optional<at::Generator> gen_
        )
{
    std::optional<Tensor> out_ = {};
    std::optional<Tensor> alibi_slopes_ = {};
    return mha_fwd(
        q, k, v,
        out_, alibi_slopes_,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        return_softmax,
        {}
    );
}
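
// Illustrative usage sketch (not part of the original source): shows how the
// simplified mha_fwd overload above might be called. The function name
// example_mha_fwd_call, the tensor shapes, and the 0.125f scale are
// assumptions for demonstration only; this assumes Tensor aliases at::Tensor,
// fp16 CUDA inputs, and that window sizes of -1 mean "no sliding window".
static void example_mha_fwd_call()
{
    const int64_t batch_size = 2, seqlen = 128, num_heads = 8, head_size = 64;
    auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    Tensor q = at::randn({batch_size, seqlen, num_heads, head_size}, opts);
    Tensor k = at::randn({batch_size, seqlen, num_heads, head_size}, opts);
    Tensor v = at::randn({batch_size, seqlen, num_heads, head_size}, opts);
    const float softmax_scale = 0.125f;  // assumed 1/sqrt(head_size) for head_size = 64
    auto outputs = mha_fwd(q, k, v,
                           /*p_dropout=*/0.0f,
                           softmax_scale,
                           /*is_causal=*/true,
                           /*window_size_left=*/-1,
                           /*window_size_right=*/-1,
                           /*return_softmax=*/false);
    (void)outputs;  // outputs[0] is expected to hold the attention output tensor
}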
std::vector<at::Tensor>
mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &out_,          // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
               at::Tensor &cu_seqlens_q,  // b+1
               at::Tensor &cu_seqlens_k,  // b+1
               c10::optional<at::Tensor> &seqused_k,     // b. If given, only this many elements of each batch element's keys are used.
               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const bool return_softmax,
               c10::optional<at::Generator> gen_);
std::vector<Tensor>
mha_varlen_fwd(Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               // std::optional<Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
               Tensor &cu_seqlens_q,  // b+1
               Tensor &cu_seqlens_k,  // b+1
               // std::optional<Tensor> &seqused_k,     // b. If given, only this many elements of each batch element's keys are used.
               // std::optional<Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
               // std::optional<Tensor> &alibi_slopes_, // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const bool return_softmax)
{
    std::optional<Tensor> out_ = {};
    std::optional<Tensor> seqused_k = {};
    std::optional<Tensor> alibi_slopes_ = {};
    return mha_varlen_fwd(
        q, k, v,
        out_,
        cu_seqlens_q, cu_seqlens_k,
        seqused_k, alibi_slopes_,
        max_seqlen_q, max_seqlen_k,
        p_dropout, softmax_scale, zero_tensors, is_causal,
        window_size_left, window_size_right,
        return_softmax,
        {}
    );
}
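
// Illustrative usage sketch (not part of the original source): shows how the
// simplified mha_varlen_fwd overload above might be called on two sequences
// packed into one "varlen" buffer. The function name, sequence lengths, shapes,
// and scale are assumptions; cu_seqlens_q / cu_seqlens_k are assumed to be
// int32 prefix sums of the per-sequence lengths (b+1 entries) on the same
// device as q/k/v.
static void example_mha_varlen_fwd_call()
{
    const int64_t num_heads = 8, head_size = 64;
    const int64_t total = 128;  // two sequences of lengths 100 and 28, packed back to back
    auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    Tensor q = at::randn({total, num_heads, head_size}, opts);
    Tensor k = at::randn({total, num_heads, head_size}, opts);
    Tensor v = at::randn({total, num_heads, head_size}, opts);
    int cu_data[] = {0, 100, 128};  // prefix sums: {0, s_0, s_0 + s_1}
    Tensor cu_seqlens =
        at::from_blob(cu_data, {3}, at::TensorOptions().dtype(at::kInt)).to(at::kCUDA);
    const int max_seqlen = 100;          // longest sequence in the batch
    const float softmax_scale = 0.125f;  // assumed 1/sqrt(head_size) for head_size = 64
    auto outputs = mha_varlen_fwd(q, k, v,
                                  cu_seqlens, cu_seqlens,  // same lengths for Q and K here
                                  max_seqlen, max_seqlen,
                                  /*p_dropout=*/0.0f,
                                  softmax_scale,
                                  /*zero_tensors=*/false,
                                  /*is_causal=*/true,
                                  /*window_size_left=*/-1,
                                  /*window_size_right=*/-1,
                                  /*return_softmax=*/false);
    (void)outputs;  // outputs[0] is expected to hold the attention output (total x num_heads x head_size)
}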
std::vector<at::Tensor>
mha_fwd_block(const at::Tensor &q,  // total_q x num_heads x head_size, total := \sum_{i=0}^{b} s_i
              const at::Tensor &k,  // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
              const at::Tensor &v,  // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i