Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
76a7983b
Unverified
Commit
76a7983b
authored
Dec 17, 2023
by
Woosuk Kwon
Committed by
GitHub
Dec 17, 2023
Browse files
[BugFix] Fix RoPE kernel on long sequences(#2164)
parent
8041b730
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
6 deletions
+6
-6
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+6
-6
No files found.
csrc/pos_encoding_kernels.cu
View file @
76a7983b
...
@@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel(
...
@@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel(
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
const
int
rot_dim
,
const
int
rot_dim
,
const
int
query_stride
,
const
int
64_t
query_stride
,
const
int
key_stride
,
const
int
64_t
key_stride
,
const
int
num_heads
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int
head_size
)
{
...
@@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel(
...
@@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel(
const
int
nq
=
num_heads
*
embed_dim
;
const
int
nq
=
num_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nq
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
nq
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int
head_idx
=
i
/
embed_dim
;
const
int
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int
64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
sin_ptr
,
rot_offset
,
embed_dim
);
...
@@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel(
...
@@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel(
const
int
nk
=
num_kv_heads
*
embed_dim
;
const
int
nk
=
num_kv_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int
head_idx
=
i
/
embed_dim
;
const
int
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
sin_ptr
,
rot_offset
,
embed_dim
);
...
@@ -89,8 +89,8 @@ void rotary_embedding(
...
@@ -89,8 +89,8 @@ void rotary_embedding(
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
size
(
-
1
)
/
head_size
;
int
query_stride
=
query
.
stride
(
-
2
);
int
64_t
query_stride
=
query
.
stride
(
-
2
);
int
key_stride
=
key
.
stride
(
-
2
);
int
64_t
key_stride
=
key
.
stride
(
-
2
);
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment