@@ -494,12 +494,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
                const int max_seqlen_k,
...
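The new block_table_ argument switches mha_varlen_fwd into paged-KV mode: instead of one packed total_k x num_heads_k x head_size buffer, K and V become a pool of fixed-size pages, and each sequence reaches its pages through its row of the block table. Below is a minimal sketch of the shapes involved (my own illustration against libtorch, with made-up sizes; it is not code from this diff):

#include <torch/torch.h>
#include <iostream>

int main() {
    // Made-up sizes for illustration only.
    const int64_t batch_size = 2, num_heads_k = 4, head_size = 64;
    const int64_t page_block_size = 16;      // tokens per physical KV page
    const int64_t max_num_blocks_per_seq = 8;
    const int64_t num_blocks = batch_size * max_num_blocks_per_seq;

    // Paged pool layout from the new signature comment:
    // num_blocks x page_block_size x num_heads_k x head_size
    auto k = torch::zeros({num_blocks, page_block_size, num_heads_k, head_size},
                          torch::dtype(torch::kHalf));
    auto v = torch::zeros_like(k);

    // block_table: batch_size x max_num_blocks_per_seq, int32 (per the checks
    // added later in this diff). Row b lists the physical pages of sequence b.
    auto block_table = torch::arange(num_blocks, torch::dtype(torch::kInt32))
                           .reshape({batch_size, max_num_blocks_per_seq});

    // cu_seqlens_k stays a b+1 prefix sum of the *logical* key lengths,
    // e.g. s_0 = 40 and s_1 = 35 tokens:
    auto cu_seqlens_k = torch::tensor({0, 40, 75}, torch::dtype(torch::kInt32));

    std::cout << k.sizes() << " " << block_table.sizes() << " "
              << cu_seqlens_k.sizes() << std::endl;
}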
@@ -535,6 +536,15 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_DEVICE(cu_seqlens_q);
     CHECK_DEVICE(cu_seqlens_k);
 
+    at::Tensor block_table;
+    const bool paged_KV = block_table_.has_value();
+    if (paged_KV) {
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+    }
+
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
...
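Those checks pin down the contract for the table itself: int32 entries and a contiguous innermost dimension. The consequence for addressing (my reading of the layout, not code from this PR) is a two-level lookup: logical token pos of sequence b lives in physical page block_table[b][pos / page_block_size], at row pos % page_block_size within that page. A self-contained sketch with hypothetical names:

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper, not part of the PR: where does logical token `pos`
// of sequence `b` land in the paged K/V pool?
struct PagedIndex {
    int64_t block;  // physical page index into the num_blocks dimension
    int64_t row;    // offset within the page (0 .. page_block_size - 1)
};

PagedIndex locate(const std::vector<std::vector<int32_t>> &block_table,
                  int64_t b, int64_t pos, int64_t page_block_size) {
    const int64_t logical_block = pos / page_block_size;
    assert(logical_block < static_cast<int64_t>(block_table[b].size()));
    return {block_table[b][logical_block], pos % page_block_size};
}

int main() {
    const int64_t page_block_size = 16;
    // batch_size x max_num_blocks_per_seq, as in the new argument's comment.
    std::vector<std::vector<int32_t>> block_table = {{3, 7}, {0, 5}};
    PagedIndex idx = locate(block_table, /*b=*/0, /*pos=*/20, page_block_size);
    std::cout << "block " << idx.block << ", row " << idx.row << "\n";  // block 7, row 4
}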
@@ -546,8 +556,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s