Unverified Commit 09e92454 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Add custom kernel for RMS normalization (#16)

parent c45f3c3a
import torch
import torch.nn as nn
from cacheflow import layernorm_ops
class RMSNorm(nn.Module):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
layernorm_ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
...@@ -12,6 +12,7 @@ from transformers import LlamaConfig ...@@ -12,6 +12,7 @@ from transformers import LlamaConfig
from cacheflow.models import InputMetadata from cacheflow.models import InputMetadata
from cacheflow.models.attention import LlamaCacheFlowAttention from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import ( from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
...@@ -23,22 +24,6 @@ from cacheflow.sequence import SequenceOutputs ...@@ -23,22 +24,6 @@ from cacheflow.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
def __init__( def __init__(
...@@ -148,8 +133,8 @@ class LlamaDecoderLayer(nn.Module): ...@@ -148,8 +133,8 @@ class LlamaDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
) )
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward( def forward(
self, self,
...@@ -190,7 +175,7 @@ class LlamaModel(nn.Module): ...@@ -190,7 +175,7 @@ class LlamaModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
perform_initialization=False) perform_initialization=False)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward( def forward(
self, self,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "attention_utils.h" #include "attention_utils.h"
#include "cuda_primitives.h" #include "cuda_primitives.h"
#include "reduction_utils.h"
#include <algorithm> #include <algorithm>
......
...@@ -159,45 +159,6 @@ struct Qk_dot<uint16_t, 4> { ...@@ -159,45 +159,6 @@ struct Qk_dot<uint16_t, 4> {
} }
}; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < WARPS_PER_BLOCK) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
}
} // namespace cacheflow } // namespace cacheflow
#undef MMHA_USE_FP32_ACUM_FOR_FMA #undef MMHA_USE_FP32_ACUM_FOR_FMA
......
#include <torch/extension.h>
void rms_norm(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& weight,
float epsilon);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
}
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "reduction_utils.h"
namespace cacheflow {
// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [num_tokens, hidden_size]
const scalar_t* __restrict__ input, // [num_tokens, hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float) input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
}
}
} // namespace cacheflow
void rms_norm(
torch::Tensor& out, // [num_tokens, hidden_size]
torch::Tensor& input, // [num_tokens, hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int num_tokens = input.size(0);
int hidden_size = input.size(1);
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(),
"rms_norm_kernel",
[&] {
cacheflow::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
});
}
#pragma once
namespace cacheflow {
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < WARPS_PER_BLOCK) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
}
#define FINAL_MASK 0xffffffff
template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
// blockDim.x is not divided by 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
} // namespace cacheflow
...@@ -31,6 +31,14 @@ positional_encoding_extension = cpp_extension.CUDAExtension( ...@@ -31,6 +31,14 @@ positional_encoding_extension = cpp_extension.CUDAExtension(
) )
ext_modules.append(positional_encoding_extension) ext_modules.append(positional_encoding_extension)
# Layer normalization kernels.
layernorm_extension = cpp_extension.CUDAExtension(
name='cacheflow.layernorm_ops',
sources=['csrc/layernorm.cpp', 'csrc/layernorm_kernels.cu'],
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(layernorm_extension)
setuptools.setup( setuptools.setup(
name='cacheflow', name='cacheflow',
ext_modules=ext_modules, ext_modules=ext_modules,
......
import torch
import torch.nn as nn
from cacheflow import layernorm_ops
class RefRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
) -> None:
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
ref = RefRMSNorm(hidden_size).to(dtype).cuda()
out = torch.empty_like(x)
layernorm_ops.rms_norm(
out,
x,
ref.weight.data,
ref.variance_epsilon,
)
ref_out = ref(x)
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
if __name__ == '__main__':
for dtype in [torch.half, torch.float]:
for num_tokens in [7, 128, 2048]:
for hidden_size in [13, 64, 1024, 5120]:
print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
f'{num_tokens}, hidden_size={hidden_size}')
test_rms_norm(
num_tokens=num_tokens,
hidden_size=hidden_size,
dtype=dtype,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment