"vscode:/vscode.git/clone" did not exist on "f89978ad7cf3732a4a482db8e784aba7e417abbd"
Commit c628c6ec authored by huangwb's avatar huangwb
Browse files

Merge branch 'vllm-v0.5.0-dtk24.04.1-rotary_embedding' into 'vllm-v0.5.0-dtk24.04.1'

add rotary_embedding for tgi

See merge request dcutoolkit/deeplearing/vllm!3
parents 22839191 e499f96c
...@@ -150,6 +150,7 @@ set(VLLM_EXT_SRC ...@@ -150,6 +150,7 @@ set(VLLM_EXT_SRC
"csrc/cache_kernels.cu" "csrc/cache_kernels.cu"
"csrc/attention/attention_kernels.cu" "csrc/attention/attention_kernels.cu"
"csrc/pos_encoding_kernels.cu" "csrc/pos_encoding_kernels.cu"
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/activation_kernels.cu" "csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
......
...@@ -38,6 +38,13 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, ...@@ -38,6 +38,13 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& cos_sin_cache, bool is_neox, torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rot_dim, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets); torch::Tensor& cos_sin_cache_offsets);
void rotary_embedding_tgi(
torch::Tensor& query,
torch::Tensor& key,
int64_t head_size,
torch::Tensor& cos_cache,
torch::Tensor& sin_cache,
bool is_neox);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
......
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace vllm {
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding_tgi(
scalar_t* __restrict__ arr,
const float* __restrict__ cos_ptr,
const float* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index;
float cos, sin;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = VLLM_LDG(cos_ptr + x_index);
sin = VLLM_LDG(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = VLLM_LDG(cos_ptr + x_index / 2);
sin = VLLM_LDG(sin_ptr + x_index / 2);
}
const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding_tgi(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const float* __restrict__ cos_ptr, // [max_position, 1, rot_dim]
const float* __restrict__ sin_ptr, // [max_position, 1, rot_dim]
const int head_size,
const int num_heads,
const int num_kv_heads,
const int rot_dim,
const int token_idx,
const int64_t query_stride,
const int64_t key_stride)
{
const int nq = num_heads * rot_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / rot_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % rot_dim;
apply_token_rotary_embedding_tgi<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, rot_dim);
}
const int nk = num_kv_heads * rot_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / rot_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % rot_dim;
apply_token_rotary_embedding_tgi<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, rot_dim);
}
}
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_tgi_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const float* __restrict__ cos_cache, // [max_position, 1, rot_dim]
const float* __restrict__ sin_cache, // [max_position, 1, rot_dim]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
const float* cos_ptr = cos_cache + token_idx * rot_dim;
const float* sin_ptr = sin_cache + token_idx * rot_dim;
apply_rotary_embedding_tgi<scalar_t, IS_NEOX>(query, key, cos_ptr, sin_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
}
} // namespace vllm
void rotary_embedding_tgi(
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int64_t head_size,
torch::Tensor& cos_cache,
torch::Tensor& sin_cache,
bool is_neox) {
int num_tokens = query.size(0);
int rot_dim = cos_cache.size(2);
int num_heads = query.size(1);
int num_kv_heads = key.size(1);
int query_stride = query.stride(0);
int key_stride = key.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
"rotary_embedding_tgi",
[&] {
if (is_neox) {
vllm::rotary_embedding_tgi_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_cache.data_ptr<float>(),
sin_cache.data_ptr<float>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
vllm::rotary_embedding_tgi_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_cache.data_ptr<float>(),
sin_cache.data_ptr<float>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}
...@@ -89,6 +89,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -89,6 +89,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()"); " Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Rotary embedding TGI for TGI
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding_tgi(Tensor! query, Tensor! key,"
" int head_size, Tensor cos_cache,"
" Tensor sin_cache, bool is_neox) -> ()");
// ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
ops.impl("rotary_embedding_tgi", torch::kCUDA, &rotary_embedding_tgi);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras). // (supports multiple loras).
ops.def( ops.def(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment