Commit 97af18e4 authored by Casper Hansen

Update rotary embedding kernel

parent 54b63712
pos_encoding.h

 #pragma once
 #include <torch/extension.h>

-void rotary_embedding_neox(
+void rotary_embedding(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
   int head_size,
-  torch::Tensor& cos_sin_cache);
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);
@@ -9,15 +9,56 @@ https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
 #include <ATen/cuda/CUDAContext.h>
 #include "pos_encoding.h"

-template<typename scalar_t>
-__global__ void rotary_embedding_neox_kernel(
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)       \
+  AT_DISPATCH_SWITCH(                                       \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+template<typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(
+  scalar_t* __restrict__ arr,
+  const scalar_t* __restrict__ cos_ptr,
+  const scalar_t* __restrict__ sin_ptr,
+  int rot_offset,
+  int embed_dim)
+{
+  int x_index, y_index;
+  scalar_t cos, sin;
+  if (IS_NEOX) {
+    // GPT-NeoX style rotary embedding.
+    x_index = rot_offset;
+    y_index = embed_dim + rot_offset;
+    cos = __ldg(cos_ptr + x_index);
+    sin = __ldg(sin_ptr + x_index);
+  } else {
+    // GPT-J style rotary embedding.
+    x_index = 2 * rot_offset;
+    y_index = 2 * rot_offset + 1;
+    cos = __ldg(cos_ptr + x_index / 2);
+    sin = __ldg(sin_ptr + x_index / 2);
+  }
+
+  const scalar_t x = arr[x_index];
+  const scalar_t y = arr[y_index];
+  arr[x_index] = x * cos - y * sin;
+  arr[y_index] = y * cos + x * sin;
+}
+
+template<typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
   const int64_t* __restrict__ positions,        // [num_tokens]
   scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int stride,
+  const int query_stride,
+  const int key_stride,
   const int num_heads,
+  const int num_kv_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
@@ -25,64 +66,72 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
   const int embed_dim = rot_dim / 2;
-  const int n = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int token_head = token_idx * stride + head_idx * head_size;
+  const scalar_t* cos_ptr = cache_ptr;
+  const scalar_t* sin_ptr = cache_ptr + embed_dim;
+
+  const int nq = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    const int x_index = rot_offset;
-    const int y_index = embed_dim + rot_offset;
-    const int out_x = token_idx * stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * stride + head_idx * head_size + y_index;
-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
-    const scalar_t q_x = query[token_head + x_index];
-    const scalar_t q_y = query[token_head + y_index];
-    query[out_x] = q_x * cos - q_y * sin;
-    query[out_y] = q_y * cos + q_x * sin;
-    const scalar_t k_x = key[token_head + x_index];
-    const scalar_t k_y = key[token_head + y_index];
-    key[out_x] = k_x * cos - k_y * sin;
-    key[out_y] = k_y * cos + k_x * sin;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
+  }
+
+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * key_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
   }
 }

-void rotary_embedding_neox(
-  torch::Tensor& positions,      // [b, num_tokens]
-  torch::Tensor& query,          // [b, num_tokens, 1, num_heads, head_size]
-  torch::Tensor& key,            // [b, num_tokens, 1, num_heads, head_size]
+void rotary_embedding(
+  torch::Tensor& positions,      // [num_tokens]
+  torch::Tensor& query,          // [num_tokens, num_heads * head_size]
+  torch::Tensor& key,            // [num_tokens, num_kv_heads * head_size]
   int head_size,
-  torch::Tensor& cos_sin_cache)  // [max_position, rot_dim]
-{
+  torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
+  bool is_neox) {
   int num_tokens = query.size(0) * query.size(1);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-2);
-  int stride = num_heads * head_size;
-  // TORCH_CHECK(stride == key.stride(0));
+  int num_kv_heads = key.size(-2);
+  int query_stride = query.stride(0);
+  int key_stride = key.stride(0);

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
-    "rotary_embedding_neox",
+    "rotary_embedding",
     [&] {
-      rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        positions.data_ptr<int64_t>(),
-        query.data_ptr<scalar_t>(),
-        key.data_ptr<scalar_t>(),
-        cos_sin_cache.data_ptr<scalar_t>(),
-        rot_dim,
-        stride,
-        num_heads,
-        head_size);
+      if (is_neox) {
+        rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      } else {
+        rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      }
     });
 }
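Note (editorial aside, not part of the commit): the only difference between the two IS_NEOX branches is how a rotated pair is laid out within a head. GPT-NeoX style pairs element i with element i + embed_dim (the two halves of the rotary dimensions), while GPT-J style pairs the adjacent elements 2i and 2i + 1; both branches read the same cos/sin value for a given rot_offset. A minimal host-side C++ sketch of the same rotation, for illustration only:

// Illustrative reference for one rotated pair; mirrors apply_rotary_embedding.
// arr points at a single head vector, cos_ptr/sin_ptr at the embed_dim
// cos/sin values cached for this token's position.
template <bool IS_NEOX>
void rotate_pair_ref(float* arr, const float* cos_ptr, const float* sin_ptr,
                     int rot_offset, int embed_dim) {
  int x_index, y_index;
  if (IS_NEOX) {
    x_index = rot_offset;              // first half of the rotary dims
    y_index = embed_dim + rot_offset;  // second half
  } else {
    x_index = 2 * rot_offset;          // even element of an adjacent pair
    y_index = 2 * rot_offset + 1;      // odd element
  }
  const float c = cos_ptr[rot_offset];
  const float s = sin_ptr[rot_offset];
  const float x = arr[x_index];
  const float y = arr[y_index];
  arr[x_index] = x * c - y * s;        // standard 2-D rotation
  arr[y_index] = y * c + x * s;
}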
@@ -8,5 +8,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
   m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
   m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
-  m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
+  m.def("rotary_embedding", &rotary_embedding, "Apply rotary embedding to query and key");
 }
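Note (editorial aside, not part of the commit): the kernel indexes cos_sin_cache per position as embed_dim cosines followed by embed_dim sines (cos_ptr = cache_ptr, sin_ptr = cache_ptr + embed_dim). A hedged libtorch sketch of building a cache in that [max_position, rot_dim] layout; the base-10000 frequency schedule is the usual RoPE convention and is an assumption here, not something this commit defines:

#include <torch/torch.h>

// Hypothetical helper (not part of this commit): builds a [max_position, rot_dim]
// cache whose first rot_dim/2 columns are cos and last rot_dim/2 columns are sin.
torch::Tensor make_cos_sin_cache(int64_t max_position, int64_t rot_dim,
                                 double base = 10000.0) {
  // inv_freq[i] = base^(-2i / rot_dim), the standard RoPE frequency schedule.
  auto inv_freq = torch::pow(
      base, torch::arange(0, rot_dim, 2, torch::kFloat64) /
                static_cast<double>(rot_dim)).reciprocal();
  auto positions = torch::arange(max_position, torch::kFloat64);
  auto freqs = torch::outer(positions, inv_freq);   // [max_position, rot_dim / 2]
  return torch::cat({freqs.cos(), freqs.sin()}, /*dim=*/1);
}

// The op exposed above would then be called roughly as
//   rotary_embedding(positions, query, key, head_size,
//                    cos_sin_cache, /*is_neox=*/true);
// after moving the cache to the same device and dtype as query.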