Unverified Commit 27b50f1f authored by Thien Tran's avatar Thien Tran Committed by GitHub
Browse files

[Bugfix][Kernel][CPU] Fix num_tokens in CPU rotary embedding kernel (#14667)


Signed-off-by: default avatarThien Tran <gau.nernst@yahoo.com.sg>
parent 9532c498
......@@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl(
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1);
int num_tokens = positions.numel();
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment