Unverified commit 88c0268a, authored by Woosuk Kwon, committed by GitHub

Implement custom kernel for LLaMA rotary embedding (#14)

parent 80a2f812
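
Background (not part of the patch): GPT-NeoX style rotary embedding rotates, within each head, the element pair (x[j], x[j + d/2]) of a query/key vector of head size d at position p by the angle theta_j = p * base^(-2j/d), for j = 0, ..., d/2 - 1. A minimal sketch of the per-pair update, matching both the PyTorch implementation removed from llama.py and the new CUDA kernel:

    x_out[j]       = x[j] * cos(theta_j) - x[j + d/2] * sin(theta_j)
    x_out[j + d/2] = x[j + d/2] * cos(theta_j) + x[j] * sin(theta_j)

The new LlamaCacheFlowAttention precomputes cos(theta_j) and sin(theta_j) for every position into a single [max_position, head_size] cache (cos in the first half of each row, sin in the second half), and the kernel applies the rotation to query and key in one pass.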
@@ -6,13 +6,14 @@ import torch.nn as nn
 from cacheflow import attention_ops
 from cacheflow import cache_ops
+from cacheflow import pos_encoding_ops
 from cacheflow.models import InputMetadata


-class OPTCacheFlowAttention(nn.Module):
+class GPTCacheFlowAttention(nn.Module):

     def __init__(self, scale: float) -> None:
-        super(OPTCacheFlowAttention, self).__init__()
+        super().__init__()
         self.scale = float(scale)

         self.flash_attn = FlashAttention(softmax_scale=self.scale)
@@ -136,3 +137,71 @@ class OPTCacheFlowAttention(nn.Module):
         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
         return output.view(-1, num_heads * head_size)
+
+
+class OPTCacheFlowAttention(GPTCacheFlowAttention):
+    """OPT uses the same attention mechanism as GPT."""
+
+    def __init__(self, scale: float) -> None:
+        super().__init__(scale)
+
+
+class LlamaCacheFlowAttention(GPTCacheFlowAttention):
+    """Llama uses GPT-NeoX style rotary embedding."""
+
+    def __init__(
+        self,
+        scale: float,
+        head_size: int,
+        max_position: int = 8192,
+        base: int = 10000,
+    ) -> None:
+        super().__init__(scale)
+
+        # Create the cos and sin cache.
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
+        t = torch.arange(max_position).float()
+        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+
+        # FIXME(woosuk): This assumes that we configure the default dtype when
+        # initializing the model. Make it more robust.
+        torch_dtype = torch.get_default_dtype()
+        cache = cache.to(torch_dtype)
+        # Embedding size: [max_position, head_size]
+        self.register_buffer('cos_sin_cache', cache, persistent=False)
+
+    def forward(
+        self,
+        positions: torch.LongTensor,              # [num_tokens]
+        query: torch.Tensor,                      # [num_tokens, num_heads * head_size]
+        key: torch.Tensor,                        # [num_tokens, num_heads * head_size]
+        value: torch.Tensor,                      # [num_tokens, num_heads * head_size]
+        key_cache: torch.Tensor,                  # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: torch.Tensor,                # [num_blocks, num_heads, head_size, block_size]
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:                            # [num_tokens, num_heads * head_size]
+        # Apply rotary embedding to the query and key before passing them
+        # to the attention op.
+        out_query = torch.empty_like(query)
+        out_key = torch.empty_like(key)
+        pos_encoding_ops.rotary_embedding_neox(
+            out_query,
+            out_key,
+            positions,
+            query,
+            key,
+            self.cos_sin_cache,
+        )
+        return super().forward(
+            out_query,
+            out_key,
+            value,
+            key_cache,
+            value_cache,
+            input_metadata,
+            cache_event,
+        )
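
For reference only (illustrative, not part of the patch), a minimal pure-PyTorch sketch of the transformation that pos_encoding_ops.rotary_embedding_neox applies, assuming the [max_position, head_size] cos/sin cache layout built above; the helper name ref_rotary_embedding_neox is made up for this sketch:

import torch

def ref_rotary_embedding_neox(
    positions: torch.LongTensor,   # [num_tokens]
    x: torch.Tensor,               # [num_tokens, num_heads * head_size], contiguous
    cos_sin_cache: torch.Tensor,   # [max_position, head_size]
) -> torch.Tensor:
    num_tokens = x.shape[0]
    head_size = cos_sin_cache.shape[1]
    x = x.view(num_tokens, -1, head_size)
    # Look up the cached angles for each position: cos is stored in the first
    # half of each row, sin in the second half.
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos = cos.unsqueeze(1)  # Broadcast over the head dimension.
    sin = sin.unsqueeze(1)
    # NeoX style: rotate the pair (x[..., j], x[..., j + head_size // 2]).
    x1, x2 = x.chunk(2, dim=-1)
    out = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return out.view(num_tokens, -1)

The fused kernel performs this computation for query and key in a single launch and writes into the preallocated out_query/out_key buffers, avoiding the transposes and intermediate tensors of the implementation removed from llama.py below.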
@@ -8,12 +8,10 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-import torch.nn.functional as F
 from transformers import LlamaConfig
-from transformers import PreTrainedModel

 from cacheflow.models import InputMetadata
-from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.attention import LlamaCacheFlowAttention
 from cacheflow.models.sample import Sampler
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -41,48 +39,8 @@ class LlamaRMSNorm(nn.Module):
         return self.weight * hidden_states


-class LlamaRotaryEmbedding(torch.nn.Module):
-
-    def __init__(self, dim, max_position_embeddings=2048, base=10000):
-        super().__init__()
-        self.max_position_embeddings = max_position_embeddings
-
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Create cos and sin embeddings.
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos().to(dtype=self.inv_freq.dtype)
-        sin = emb.sin().to(dtype=self.inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.LongTensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-        return cos, sin
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    # TODO: Optimize.
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 class LlamaMLP(nn.Module):

     def __init__(
         self,
         hidden_size: int,
@@ -156,9 +114,7 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
-        # FIXME(woosuk): Rename this.
-        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)

     def forward(
         self,
@@ -171,19 +127,9 @@ class LlamaAttention(nn.Module):
         q, _ = self.q_proj(hidden_states)
         k, _ = self.k_proj(hidden_states)
         v, _ = self.v_proj(hidden_states)
-
-        # Apply rotrary embedding.
-        # TODO: Optimize.
-        q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
-        k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
-        cos, sin = self.rotary_emb(positions)
-        q, k = apply_rotary_pos_emb(q, k, cos, sin)
-        q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
-        k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
-
-        key_cache, value_cache = kv_cache
+        k_cache, v_cache = kv_cache
         attn_output = self.attn(
-            q, k, v, key_cache, value_cache, input_metadata, cache_event)
+            positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
@@ -165,8 +165,7 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         self.head_size = config.hidden_size // self.num_heads
         self.ffn_size = config.intermediate_size
         self.vocab_size = config.vocab_size
-        # FIXME
-        self.max_position = 2048
+        self.max_position = 8192

     def _get_param_size(self) -> int:
         word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
...
@@ -51,7 +51,7 @@ class OPTAttention(nn.Module):
         assert num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = embed_dim // total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.scaling = self.head_dim ** -0.5

         # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
         self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
@@ -66,7 +66,6 @@
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-
         self.attn = OPTCacheFlowAttention(scale=self.scaling)
...
@@ -12,7 +12,7 @@ from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_pa
 class Sampler(nn.Module):

     def __init__(self) -> None:
-        super(Sampler, self).__init__()
+        super().__init__()

     def forward(
         self,
...
@@ -122,13 +122,13 @@ void reshape_and_cache(
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping) {
   int num_tokens = key.size(0);
-  int head_num = key.size(1);
+  int num_heads = key.size(1);
   int head_size = key.size(2);
   int block_size = key_cache.size(3);
   int x = key_cache.size(4);

   dim3 grid(num_tokens);
-  dim3 block(std::min(head_num * head_size, 512));
+  dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     key.scalar_type(),
@@ -140,7 +140,7 @@ void reshape_and_cache(
         key_cache.data_ptr<scalar_t>(),
         value_cache.data_ptr<scalar_t>(),
         slot_mapping.data_ptr<int>(),
-        head_num,
+        num_heads,
         head_size,
         block_size,
         x);
...
// csrc/pos_encoding.cpp (new file in this commit): Python binding for the op.
#include <torch/extension.h>

void rotary_embedding_neox(
  torch::Tensor& out_query,
  torch::Tensor& out_key,
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  torch::Tensor& cos_sin_cache);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "rotary_embedding_neox",
    &rotary_embedding_neox,
    "Apply GPT-NeoX style rotary embedding to query and key");
}
// csrc/pos_encoding_kernels.cu (new file in this commit): the CUDA kernel.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
  scalar_t* __restrict__ out_query,           // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ out_key,             // [num_tokens, num_heads, head_size]
  const int64_t* __restrict__ positions,      // [num_tokens]
  const scalar_t* __restrict__ query,         // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
  const int num_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;

  const int embed_dim = head_size / 2;
  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int idx = token_idx * n + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int token_head = token_idx * n + head_idx * head_size;

    const bool is_first_half = head_offset < embed_dim;
    const int rot_offset = head_offset % embed_dim;
    const int x_index = rot_offset;
    const int y_index = embed_dim + rot_offset;

    const scalar_t cos = __ldg(cache_ptr + x_index);
    const scalar_t sin = __ldg(cache_ptr + y_index);

    const scalar_t q_x = __ldg(query + token_head + x_index);
    const scalar_t q_y = __ldg(query + token_head + y_index);
    const scalar_t q_cos = is_first_half ? q_x : q_y;
    const scalar_t q_sin = is_first_half ? -q_y : q_x;
    out_query[idx] = q_cos * cos + q_sin * sin;

    const scalar_t k_x = __ldg(key + token_head + x_index);
    const scalar_t k_y = __ldg(key + token_head + y_index);
    const scalar_t k_cos = is_first_half ? k_x : k_y;
    const scalar_t k_sin = is_first_half ? -k_y : k_x;
    out_key[idx] = k_cos * cos + k_sin * sin;
  }
}

} // namespace cacheflow

void rotary_embedding_neox(
  torch::Tensor& out_query,     // [num_tokens, num_heads * head_size]
  torch::Tensor& out_key,       // [num_tokens, num_heads * head_size]
  torch::Tensor& positions,     // [num_tokens]
  torch::Tensor& query,         // [num_tokens, num_heads * head_size]
  torch::Tensor& key,           // [num_tokens, num_heads * head_size]
  torch::Tensor& cos_sin_cache) // [max_position, head_size]
{
  int num_tokens = query.size(0);
  int head_size = cos_sin_cache.size(1);
  int num_heads = query.size(1) / head_size;
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    query.scalar_type(),
    "rotary_embedding_neox",
    [&] {
      cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out_query.data_ptr<scalar_t>(),
        out_key.data_ptr<scalar_t>(),
        positions.data_ptr<int64_t>(),
        query.data_ptr<scalar_t>(),
        key.data_ptr<scalar_t>(),
        cos_sin_cache.data_ptr<scalar_t>(),
        num_heads,
        head_size);
    });
}
@@ -23,6 +23,14 @@ attention_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(attention_extension)

+# Positional encodings.
+positional_encoding_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.pos_encoding_ops',
+    sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(positional_encoding_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,
...
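
With the new extension registered in setup.py, rebuilding the package (for example via pip install -e ., assuming an editable install; the build command itself is not part of this diff) exposes the op used by the model code above and the test below:

from cacheflow import pos_encoding_ops

# Outputs come first, then inputs, matching the declaration in csrc/pos_encoding.cpp:
# pos_encoding_ops.rotary_embedding_neox(
#     out_query, out_key, positions, query, key, cos_sin_cache)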
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from cacheflow import pos_encoding_ops


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RefRotaryEmbeddingNeox(nn.Module):
    """Reference implementation of the GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
    ) -> None:
        super().__init__()
        self.max_position_embeddings = max_position_embeddings

        # Create cos and sin embeddings.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=inv_freq.dtype)
        sin = emb.sin().to(dtype=inv_freq.dtype)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,  # [num_tokens]
        query: torch.Tensor,          # [num_tokens, num_heads, head_size]
        key: torch.Tensor,            # [num_tokens, num_heads, head_size]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)

        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        query, key = apply_rotary_pos_emb(query, key, cos, sin)
        query = query.transpose(0, 1).contiguous()
        key = key.transpose(0, 1).contiguous()
        # Output query/key shape: [num_tokens, num_heads, head_size]
        return query, key
@torch.inference_mode()
def test_rotary_embedding_neox(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    max_position: int,
    dtype: torch.dtype,
    base: int = 10000,
) -> None:
    positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
    query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
    key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')

    # Create the rotary embedding.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
    t = torch.arange(max_position).float()
    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')

    # Run the kernel.
    out_query = torch.empty_like(query)
    out_key = torch.empty_like(key)
    pos_encoding_ops.rotary_embedding_neox(
        out_query,
        out_key,
        positions,
        query,
        key,
        cos_sin_cache,
    )

    # Run the reference implementation.
    ref_rotary_embedding = RefRotaryEmbeddingNeox(
        dim=head_size,
        max_position_embeddings=max_position,
        base=base,
    ).to(dtype=dtype, device='cuda')
    ref_query, ref_key = ref_rotary_embedding(
        positions,
        query.view(num_tokens, num_heads, head_size),
        key.view(num_tokens, num_heads, head_size),
    )
    ref_query = ref_query.view(num_tokens, num_heads * head_size)
    ref_key = ref_key.view(num_tokens, num_heads * head_size)

    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)


if __name__ == '__main__':
    for dtype in [torch.half, torch.float]:
        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
            print(f'Running tests for head_size={head_size} and dtype={dtype}')
            test_rotary_embedding_neox(
                num_tokens=2145,
                num_heads=5,
                head_size=head_size,
                max_position=8192,
                dtype=dtype,
            )