Unverified Commit e5da6e4d authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

Fix out kwarg shape check with ngroups swapped (#4)

parent 03bf1f8a
...@@ -726,7 +726,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ...@@ -726,7 +726,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
const int batch_size = sizes[0]; const int batch_size = sizes[0];
int seqlen_q = sizes[1]; int seqlen_q = sizes[1];
const int seqlen_q_og = seqlen_q;
int num_heads = sizes[2]; int num_heads = sizes[2];
const int num_heads_og = num_heads;
const int head_size_og = sizes[3]; const int head_size_og = sizes[3];
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
...@@ -784,8 +786,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ...@@ -784,8 +786,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out); CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(out, batch_size, seqlen_q_og, num_heads_og, head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } if (head_size_og % 8 != 0) {
out = torch::empty_like(q_padded);
} else if (seqlenq_ngroups_swapped) {
out = out.reshape({batch_size, num_heads, seqlen_q, head_size_og}).transpose(1, 2);
}
} else { } else {
out = torch::empty_like(q_padded); out = torch::empty_like(q_padded);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment