Commit 90ddfba8 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix fa error and remove layernorm kernel

parent 7f7894c0
......@@ -297,7 +297,6 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu"
# "csrc/opt/layernorm_kernels_opt.cu"
"csrc/fused_qknorm_rope_kernel.cu"
# "csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
......
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
variance += x * x;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * input_stride + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
const int vec_hidden_size = hidden_size / width;
const int64_t vec_input_stride = input_stride / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
_f16Vec<scalar_t, width> temp = input_v[strided_id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[strided_id] = temp;
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * input_stride + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * input_stride + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
} // namespace vllm
template <typename T,int reducesize=C10_WARP_SIZE>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
#pragma unroll
for (int offset = reducesize/2; offset > 0; offset >>= 1) {
val += WARP_SHFL_DOWN(val, offset);
}
return val;
}
template <typename T,int block_size=512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
constexpr int share_size=block_size/C10_WARP_SIZE;
val = WarpReduceSum_NEW<T>(val);
if constexpr(block_size==C10_WARP_SIZE)
{
return val;
}
else{
const int lid = threadIdx.x % C10_WARP_SIZE;
const int wid = threadIdx.x / C10_WARP_SIZE;
if (lid == 0&&wid<share_size) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0&&lid<share_size) {
val = WarpReduceSum_NEW<T,share_size>(shared[lid]);
}
return val;
}
}
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_add_rms_kernel_opt(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps)
{
constexpr int share_size=block_size/C10_WARP_SIZE;
__shared__ T_ACC val_shared[share_size];
__shared__ T_ACC s_rstd;
T_ACC val=0;
int i=blockIdx.x;
int j=threadIdx.x;
int tcol=cols/Vec;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
scalar_t residual_vec[Vec];
T_ACC trstd;
int64_t idx = i * tcol + j;
idx*=Vec;
if (j < tcol) {
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
*(LoadT*)residual_vec = *(LoadT*)(residual+idx);
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
residual_vec[ii]+=intput_vec[ii];
val += static_cast<T_ACC>(residual_vec[ii])*static_cast<T_ACC>(residual_vec[ii]);
}
}
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads();
trstd=s_rstd;
if (j < tcol) {
#pragma unroll
for(int ii=0;ii<Vec;ii++){
int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(residual_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(residual+idx)=*(LoadT*)residual_vec;
*(LoadT*)(input+idx)=*(LoadT*)intput_vec;
}
}
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_rms_kernel_opt(scalar_t* input,scalar_t* output,scalar_t* gamma,int cols,T_ACC eps)
{
constexpr int share_size=block_size/C10_WARP_SIZE;
__shared__ T_ACC val_shared[share_size];
__shared__ T_ACC s_rstd;
T_ACC val=0;
int i=blockIdx.x;
int j=threadIdx.x;
int tcol=cols/Vec;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
T_ACC trstd;
int64_t idx = i * tcol + j;
idx*=Vec;
if (j < tcol) {
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
val += static_cast<T_ACC>(intput_vec[ii])*static_cast<T_ACC>(intput_vec[ii]);
}
}
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads();
trstd=s_rstd;
if (j < tcol) {
#pragma unroll
for(int ii=0;ii<Vec;ii++){
int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(intput_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(output+idx)=*(LoadT*)intput_vec;
}
}
void rms_norm_opt(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.stride(-1) == 1);
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int64_t input_stride = input.stride(-2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* out_data = out.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* weight_data= weight.expect_contiguous()->data_ptr<scalar_t>();
if (hidden_size<=1024){
fused_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=2048){
fused_rms_kernel_opt<scalar_t,T_ACC,8,256><<<num_tokens, 256, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=4096){
if(num_tokens>1200){
fused_rms_kernel_opt<scalar_t,T_ACC,8,512><<<num_tokens, 512, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else{
fused_rms_kernel_opt<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
}
else if(hidden_size<=8192){
fused_rms_kernel_opt<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else{
fused_rms_kernel_opt<scalar_t,T_ACC,16,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
});
}
else{
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), input_stride, \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
epsilon, num_tokens, hidden_size); \
});
void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
int64_t input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* other_data = residual.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* weight_data= weight.expect_contiguous()->data_ptr<scalar_t>();
if (hidden_size<=1024){
fused_add_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=2048){
fused_add_rms_kernel_opt<scalar_t,T_ACC,8,256><<<num_tokens, 256, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=4096){
if(num_tokens>1200){
fused_add_rms_kernel_opt<scalar_t,T_ACC,8,512><<<num_tokens, 512, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else{
fused_add_rms_kernel_opt<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
}
else if(hidden_size<=8192){
fused_add_rms_kernel_opt<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else{
fused_add_rms_kernel_opt<scalar_t,T_ACC,16,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
});
}
else{
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
constexpr int vector_width = 8;
constexpr int req_alignment_bytes =
vector_width * 2; // vector_width * sizeof(bfloat16 or float16) (float32
// falls back to non-vectorized version anyway)
bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
res_ptr % req_alignment_bytes == 0 &&
wt_ptr % req_alignment_bytes == 0;
bool offsets_are_multiple_of_vector_width =
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
}
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, Optional
import torch
......@@ -348,17 +348,6 @@ def fused_add_rms_norm(
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# layer norm ops (opt)
# def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
# epsilon: float) -> None:
# torch.ops._C.rms_norm_opt(out, input, weight, epsilon)
# def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
# weight: torch.Tensor, epsilon: float) -> None:
# torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
def fused_qk_norm_rope(
qkv: torch.Tensor,
num_heads_q: int,
......
......@@ -15,6 +15,7 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.platforms import current_platform
from vllm import envs
import lightop as op
def rms_norm(
......@@ -27,10 +28,10 @@ def rms_norm(
out = torch.empty_like(x)
# if envs.VLLM_USE_OPT_OP:
if False:
ops.rms_norm_opt(
out,
op.rmsnorm_forward(
x,
weight,
out,
variance_epsilon,
)
else:
......@@ -57,7 +58,7 @@ def fused_add_rms_norm(
), x + residual
# if envs.VLLM_USE_OPT_OP:
if False:
ops.fused_add_rms_norm_opt(
op.rn_add_forward_autograd(
x,
residual,
weight,
......
......@@ -22,10 +22,8 @@ elif current_platform.is_xpu():
elif current_platform.is_rocm():
try:
# from flash_attn import flash_attn_varlen_func # noqa: F401
from vllm import _custom_ops as ops
from vllm._custom_ops import reshape_and_cache_cuda
from flash_attn import vllm_flash_attn_varlen_func
reshape_and_cache_cuda = ops.reshape_and_cache_cuda
except ImportError as e:
raise ImportError(
"Rocm platform requires upstream flash-attn "
......@@ -41,7 +39,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
return 2
if current_platform.is_rocm():
# ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
return None
return 2 # None
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason,
......
......@@ -742,7 +742,7 @@ class FlashAttentionImpl(AttentionImpl):
layer._v_scale
)
else:
from vllm.attention.utils.fa_utils import reshape_and_cache_cuda
from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment