Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e499f96c
Commit
e499f96c
authored
Jul 17, 2024
by
huangwb
Browse files
add rotary_embedding for tgi
parent
22839191
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
161 additions
and
0 deletions
+161
-0
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/ops.h
csrc/ops.h
+7
-0
csrc/pos_encoding_tgi_kernels.cu
csrc/pos_encoding_tgi_kernels.cu
+144
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+9
-0
No files found.
CMakeLists.txt
View file @
e499f96c
...
@@ -150,6 +150,7 @@ set(VLLM_EXT_SRC
...
@@ -150,6 +150,7 @@ set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/cache_kernels.cu"
"csrc/attention/attention_kernels.cu"
"csrc/attention/attention_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
...
...
csrc/ops.h
View file @
e499f96c
...
@@ -38,6 +38,13 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
...
@@ -38,6 +38,13 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int64_t
rot_dim
,
int64_t
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
torch
::
Tensor
&
cos_sin_cache_offsets
);
// Applies GPT-NeoX style (is_neox == true) or GPT-J style (is_neox == false)
// rotary embedding to `query` and `key` in place.
// The implementation reads the token count from query.size(0), the head
// counts from query.size(1) / key.size(1), and the number of rotated pairs per
// head from cos_cache.size(2). `cos_cache` and `sin_cache` are accessed as
// float data, with row t used for token t.
void
rotary_embedding_tgi
(
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_cache
,
torch
::
Tensor
&
sin_cache
,
bool
is_neox
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
csrc/pos_encoding_tgi_kernels.cu
0 → 100644
View file @
e499f96c
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace
vllm
{
// Rotates a single (x, y) element pair of one head in place.
//
//   arr        : start of the head's data.
//   cos_ptr    : cosine table for the current token; entry `rot_offset` is used.
//   sin_ptr    : sine table for the current token; entry `rot_offset` is used.
//   rot_offset : index of the pair to rotate.
//   embed_dim  : distance between the two elements of a pair in the NeoX
//                layout (unused for GPT-J).
//
// Pair layout:
//   IS_NEOX == true  -> elements (rot_offset, embed_dim + rot_offset)
//   IS_NEOX == false -> elements (2 * rot_offset, 2 * rot_offset + 1)
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding_tgi(
    scalar_t* __restrict__ arr, const float* __restrict__ cos_ptr,
    const float* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
  // Positions of the two elements that form the rotated pair.
  const int first = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int second = IS_NEOX ? embed_dim + rot_offset : 2 * rot_offset + 1;

  // Both layouts use the same table entry: one cos/sin value per pair.
  const float c = VLLM_LDG(cos_ptr + rot_offset);
  const float s = VLLM_LDG(sin_ptr + rot_offset);

  // Read both inputs before writing so the rotation uses the original values.
  const scalar_t a = arr[first];
  const scalar_t b = arr[second];
  arr[first] = a * c - b * s;
  arr[second] = b * c + a * s;
}
// Applies the rotary embedding to every query head and every key head of one
// token.
//
//   query, key       : base pointers of the query / key tensors.
//   cos_ptr, sin_ptr : cosine / sine values for this token; entries
//                      [0, rot_dim) are read.
//   head_size        : element distance between consecutive heads of a token.
//   num_heads        : number of query heads.
//   num_kv_heads     : number of key heads.
//   rot_dim          : number of rotated pairs per head.
//   token_idx        : token handled by the calling thread block.
//   query_stride,
//   key_stride       : element distance between consecutive tokens.
//
// Work is split across the threads of the block: thread t processes items
// t, t + blockDim.x, t + 2 * blockDim.x, ... where one item is one
// (head, pair) combination.
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding_tgi(
    scalar_t* __restrict__ query, scalar_t* __restrict__ key,
    const float* __restrict__ cos_ptr, const float* __restrict__ sin_ptr,
    const int head_size, const int num_heads, const int num_kv_heads,
    const int rot_dim, const int token_idx, const int64_t query_stride,
    const int64_t key_stride) {
  // Query heads.
  const int query_items = num_heads * rot_dim;
  for (int item = threadIdx.x; item < query_items; item += blockDim.x) {
    const int head = item / rot_dim;
    const int pair = item % rot_dim;
    const int64_t head_start = token_idx * query_stride + head * head_size;
    apply_token_rotary_embedding_tgi<scalar_t, IS_NEOX>(
        query + head_start, cos_ptr, sin_ptr, pair, rot_dim);
  }

  // Key heads.
  const int key_items = num_kv_heads * rot_dim;
  for (int item = threadIdx.x; item < key_items; item += blockDim.x) {
    const int head = item / rot_dim;
    const int pair = item % rot_dim;
    const int64_t head_start = token_idx * key_stride + head * head_size;
    apply_token_rotary_embedding_tgi<scalar_t, IS_NEOX>(
        key + head_start, cos_ptr, sin_ptr, pair, rot_dim);
  }
}
// Kernel entry point: rotates the query and key heads of one token per thread
// block, using cosine/sine tables that hold one row per token.
//
//   query        : query data; token t starts at t * query_stride and head h at
//                  a further h * head_size elements.
//   key          : key data, addressed the same way with key_stride.
//   cos_cache    : float table; row t (rot_dim entries starting at
//                  t * rot_dim) is used for token t.
//   sin_cache    : float table with the same layout as cos_cache.
//   rot_dim      : number of rotated pairs per head (entries per cache row).
//   query_stride : element distance between consecutive tokens in `query`.
//   key_stride   : element distance between consecutive tokens in `key`.
//   num_heads    : number of query heads.
//   num_kv_heads : number of key heads.
//   head_size    : element distance between consecutive heads.
template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_tgi_kernel(
    scalar_t* __restrict__ query, scalar_t* __restrict__ key,
    const float* __restrict__ cos_cache, const float* __restrict__ sin_cache,
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;

  // Compute the row offset in 64 bits so that token_idx * rot_dim cannot
  // overflow a 32-bit int for very large inputs.
  const int64_t cache_offset = static_cast<int64_t>(token_idx) * rot_dim;
  const float* cos_ptr = cos_cache + cache_offset;
  const float* sin_ptr = sin_cache + cache_offset;

  apply_rotary_embedding_tgi<scalar_t, IS_NEOX>(
      query, key, cos_ptr, sin_ptr, head_size, num_heads, num_kv_heads,
      rot_dim, token_idx, query_stride, key_stride);
}
}
// namespace vllm
// Applies GPT-NeoX style (is_neox == true) or GPT-J style (is_neox == false)
// rotary embedding to `query` and `key` in place, using per-token cosine and
// sine tables.
//
// Layout as read by this function:
//   query     : dim 0 = tokens, dim 1 = query heads. Only stride(0) is passed
//               to the kernel, so each head is assumed to occupy `head_size`
//               contiguous elements directly after the previous head.
//   key       : dim 0 = tokens, dim 1 = key heads, same assumption as query.
//   head_size : number of elements per head.
//   cos_cache : float32; size(2) is the number of rotated pairs per head
//               (`rot_dim`). Row t (rot_dim consecutive floats starting at
//               t * rot_dim) is used for token t.
//   sin_cache : float32, same layout as cos_cache.
//
// Throws (via TORCH_CHECK) if the inputs would make the kernel access memory
// outside a head or outside the cos/sin tables.
void rotary_embedding_tgi(
    torch::Tensor& query,      // [num_tokens, num_heads, head_size]
    torch::Tensor& key,        // [num_tokens, num_kv_heads, head_size]
    int64_t head_size,
    torch::Tensor& cos_cache,  // [num_tokens, 1, rot_dim], float32
    torch::Tensor& sin_cache,  // [num_tokens, 1, rot_dim], float32
    bool is_neox) {
  const int num_tokens = query.size(0);
  const int rot_dim = cos_cache.size(2);
  const int num_heads = query.size(1);
  const int num_kv_heads = key.size(1);
  // Strides are 64-bit in ATen and the kernel takes int64_t; keep them wide
  // instead of truncating to int.
  const int64_t query_stride = query.stride(0);
  const int64_t key_stride = key.stride(0);

  // Each rotated pair touches two elements of a head, so a head must hold at
  // least 2 * rot_dim elements or the kernel writes past the head.
  TORCH_CHECK(2 * static_cast<int64_t>(rot_dim) <= head_size,
              "rotary_embedding_tgi: 2 * cos_cache.size(2) (", 2 * rot_dim,
              ") must not exceed head_size (", head_size, ")");
  TORCH_CHECK(key.size(0) >= num_tokens,
              "rotary_embedding_tgi: key has fewer tokens (", key.size(0),
              ") than query (", num_tokens, ")");
  // The kernel reads rot_dim floats per token from each table.
  const int64_t required_cache_elems =
      static_cast<int64_t>(num_tokens) * rot_dim;
  TORCH_CHECK(cos_cache.numel() >= required_cache_elems,
              "rotary_embedding_tgi: cos_cache is too small for ", num_tokens,
              " tokens");
  TORCH_CHECK(sin_cache.numel() >= required_cache_elems,
              "rotary_embedding_tgi: sin_cache is too small for ", num_tokens,
              " tokens");

  // Nothing to rotate. Launching with a zero-sized grid is an invalid
  // configuration, so return before touching the device.
  if (num_tokens == 0 || rot_dim == 0) {
    return;
  }

  // One block per token. The kernel loops over its work items, so the block
  // size only needs to be a valid, non-zero thread count.
  int threads = std::min(num_heads * rot_dim / 2, 512);
  if (threads < 1) {
    threads = 1;
  }
  dim3 grid(num_tokens);
  dim3 block(threads);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "rotary_embedding_tgi", [&] {
        if (is_neox) {
          vllm::rotary_embedding_tgi_kernel<scalar_t, true>
              <<<grid, block, 0, stream>>>(
                  query.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(),
                  cos_cache.data_ptr<float>(), sin_cache.data_ptr<float>(),
                  rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
                  head_size);
        } else {
          vllm::rotary_embedding_tgi_kernel<scalar_t, false>
              <<<grid, block, 0, stream>>>(
                  query.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(),
                  cos_cache.data_ptr<float>(), sin_cache.data_ptr<float>(),
                  rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
                  head_size);
        }
      });
}
csrc/torch_bindings.cpp
View file @
e499f96c
...
@@ -89,6 +89,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -89,6 +89,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
// Rotary embedding variant for TGI (takes separate cos and sin tensors).
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
"rotary_embedding_tgi(Tensor! query, Tensor! key,"
" int head_size, Tensor cos_cache,"
" Tensor sin_cache, bool is_neox) -> ()"
);
// ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
ops
.
impl
(
"rotary_embedding_tgi"
,
torch
::
kCUDA
,
&
rotary_embedding_tgi
);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
// (supports multiple loras).
ops
.
def
(
ops
.
def
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment