Commit 88964968 authored by Casper Hansen's avatar Casper Hansen
Browse files

Update kernel to original

parent 97af18e4
...@@ -94,10 +94,10 @@ void rotary_embedding( ...@@ -94,10 +94,10 @@ void rotary_embedding(
int head_size, int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) { bool is_neox) {
int num_tokens = query.size(0) * query.size(1); int num_tokens = query.size(0);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-2); int num_heads = query.size(1) / head_size;
int num_kv_heads = key.size(-2); int num_kv_heads = key.size(1) / head_size;
int query_stride = query.stride(0); int query_stride = query.stride(0);
int key_stride = key.stride(0); int key_stride = key.stride(0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment