Unverified Commit 76a7983b authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[BugFix] Fix RoPE kernel on long sequences (#2164)

parent 8041b730
...@@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel( ...@@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel(
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim, const int rot_dim,
const int query_stride, const int64_t query_stride,
const int key_stride, const int64_t key_stride,
const int num_heads, const int num_heads,
const int num_kv_heads, const int num_kv_heads,
const int head_size) { const int head_size) {
...@@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel( ...@@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel(
const int nq = num_heads * embed_dim; const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) { for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int token_head = token_idx * query_stride + head_idx * head_size; const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim); sin_ptr, rot_offset, embed_dim);
...@@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel( ...@@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel(
const int nk = num_kv_heads * embed_dim; const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) { for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim); sin_ptr, rot_offset, embed_dim);
...@@ -89,8 +89,8 @@ void rotary_embedding( ...@@ -89,8 +89,8 @@ void rotary_embedding(
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size; int num_kv_heads = key.size(-1) / head_size;
int query_stride = query.stride(-2); int64_t query_stride = query.stride(-2);
int key_stride = key.stride(-2); int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment