"tests/vscode:/vscode.git/clone" did not exist on "b4a2f3ac369043b4a734160215575f2bc8037678"
Commit 4405f82c authored by zhuwenwen's avatar zhuwenwen
Browse files

Refactoring the optimized kernel

parent 3a6764a4
......@@ -155,7 +155,10 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/transpose_kernels.cu"
"csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu"
"csrc/opt/attention_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
......
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath>
......@@ -24,60 +23,6 @@ __global__ void act_and_mul_kernel(
}
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize1(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int64_t token_idx= blockIdx.x;
int idx = threadIdx.x * VEC;
if (idx < d) {
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize2(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int64_t token_idx = blockIdx.x;
int idx = threadIdx.x * VEC;
for (; idx < d; idx += blockDim.x * VEC) {
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
}
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
......@@ -109,42 +54,19 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
} // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
if (0 == d % 8 && d <= 16384) { \
if (d <= 512) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 2> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else { \
vllm::act_and_mul_kernel_vectorize2<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
} else { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
void silu_and_mul(torch::Tensor& out, // [..., d]
......@@ -225,4 +147,4 @@ void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
}
\ No newline at end of file
This diff is collapsed.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#ifndef USE_ROCM
......@@ -291,168 +288,22 @@ fused_add_rms_norm_kernel(
} // namespace vllm
template <typename T,int reducesize=C10_WARP_SIZE>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
#pragma unroll
for (int offset = reducesize/2; offset > 0; offset >>= 1) {
val += WARP_SHFL_DOWN(val, offset);
}
return val;
}
template <typename T,int block_size=512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
constexpr int share_size=block_size/C10_WARP_SIZE;
val = WarpReduceSum_NEW<T>(val);
if constexpr(block_size==C10_WARP_SIZE)
{
return val;
}
else{
const int lid = threadIdx.x % C10_WARP_SIZE;
const int wid = threadIdx.x / C10_WARP_SIZE;
if (lid == 0&&wid<share_size) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0&&lid<share_size) {
val = WarpReduceSum_NEW<T,share_size>(shared[lid]);
}
return val;
}
}
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps)
{
constexpr int share_size=block_size/C10_WARP_SIZE;
__shared__ T_ACC val_shared[share_size];
__shared__ T_ACC s_rstd;
T_ACC val=0;
int i=blockIdx.x;
int j=threadIdx.x;
int tcol=cols/Vec;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
scalar_t residual_vec[Vec];
T_ACC trstd;
int idx = i * tcol + j;
idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
*(LoadT*)residual_vec = *(LoadT*)(residual+idx);
if (j < tcol) {
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
residual_vec[ii]+=intput_vec[ii];
val += static_cast<T_ACC>(residual_vec[ii])*static_cast<T_ACC>(residual_vec[ii]);
}
}
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads();
trstd=s_rstd;
if (j < tcol) {
#pragma unroll
for(int ii=0;ii<Vec;ii++){
int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(residual_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(residual+idx)=*(LoadT*)residual_vec;
*(LoadT*)(input+idx)=*(LoadT*)intput_vec;
}
}
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t* gamma,int cols,T_ACC eps)
{
constexpr int share_size=block_size/C10_WARP_SIZE;
__shared__ T_ACC val_shared[share_size];
__shared__ T_ACC s_rstd;
T_ACC val=0;
int i=blockIdx.x;
int j=threadIdx.x;
int tcol=cols/Vec;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
T_ACC trstd;
int idx = i * tcol + j;
idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
if (j < tcol) {
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
val += static_cast<T_ACC>(intput_vec[ii])*static_cast<T_ACC>(intput_vec[ii]);
}
}
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads();
trstd=s_rstd;
if (j < tcol) {
#pragma unroll
for(int ii=0;ii<Vec;ii++){
int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(intput_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(output+idx)=*(LoadT*)intput_vec;
}
}
void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>();
scalar_t* out_data =out.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>();
if (hidden_size<=1024){
fused_rms_kernel_eval<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=2048){
fused_rms_kernel_eval<scalar_t,T_ACC,8,256><<<num_tokens, 256, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=4096){
if(num_tokens>1200){
fused_rms_kernel_eval<scalar_t,T_ACC,8,512><<<num_tokens, 512, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else{
fused_rms_kernel_eval<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
}
else if(hidden_size<=8192){
fused_rms_kernel_eval<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else{
fused_rms_kernel_eval<scalar_t,T_ACC,16,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
});
}
else{
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
}
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
......@@ -465,74 +316,37 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
num_tokens, hidden_size); \
});
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>();
scalar_t* other_data =residual.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>();
if (hidden_size<=1024){
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=2048){
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,256><<<num_tokens, 256, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=4096){
if(num_tokens>1200){
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,512><<<num_tokens, 512, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else{
fused_add_rms_kernel_eval<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
}
else if(hidden_size<=8192){
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else{
fused_add_rms_kernel_eval<scalar_t,T_ACC,16,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
});
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
else{
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
}
}
\ No newline at end of file
......@@ -23,12 +23,39 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon);
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon);
void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
......@@ -56,6 +83,14 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void gelu_tanh_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
void gelu_new_opt(torch::Tensor& out, torch::Tensor& input);
void gelu_fast_opt(torch::Tensor& out, torch::Tensor& input);
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
......@@ -119,8 +154,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
......
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath>
#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace vllm {
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = ACT_FN(x) * y;
}
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize1(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int64_t token_idx= blockIdx.x;
int idx = threadIdx.x * VEC;
if (idx < d) {
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize2(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int64_t token_idx = blockIdx.x;
int idx = threadIdx.x * VEC;
for (; idx < d; idx += blockDim.x * VEC) {
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
}
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float)x;
constexpr float ALPHA = M_SQRT1_2;
return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float)x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715;
float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube);
return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}
} // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
if (0 == d % 8 && d <= 16384) { \
if (d <= 512) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 2> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else { \
vllm::act_and_mul_kernel_vectorize2<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
} else { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
});
void silu_and_mul_opt(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
void gelu_and_mul_opt(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}
void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}
namespace vllm {
// Element-wise activation kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
} // namespace vllm
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
namespace vllm {
template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
const float x3 = (float)(x * x * x);
const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T)0.5) * x * (((T)1.0) + t);
}
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float)x;
const T t =
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
}
} // namespace vllm
void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
This diff is collapsed.
This diff is collapsed.
......@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
......@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
// Activation function used in SwiGLU. (opt)
ops.def("silu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul_opt", torch::kCUDA, &silu_and_mul_opt);
// Activation function used in GeGLU with `none` approximation. (opt)
ops.def("gelu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul_opt", torch::kCUDA, &gelu_and_mul_opt);
// Activation function used in GeGLU with `tanh` approximation. (opt)
ops.def("gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul_opt", torch::kCUDA, &gelu_tanh_and_mul_opt);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
......@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
ops.def(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);
// In-place fused Add and RMS Normalization. (opt)
ops.def(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
......@@ -108,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache_offsets) -> ()");
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
......@@ -159,10 +215,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantized GEMM for SqueezeLLM.
ops.def(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
......
......@@ -38,6 +38,18 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul(out, x)
def silu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul_opt(out, x)
def gelu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul_opt(out, x)
def gelu_tanh_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul_opt(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
......@@ -107,6 +119,65 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step)
# page attention ops (opt)
def paged_attention_v1_opt(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_opt(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v2_opt(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_opt(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,
......@@ -139,6 +210,21 @@ def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# layer norm ops (opt)
def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
torch.ops._C.rms_norm_opt(out, input, weight, epsilon)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
row:int, col:int) -> None :
torch.ops._C.trans_w16_gemm(dst,src,row,col)
# quantization ops
......@@ -206,11 +292,7 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
quant_ops.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
row:int, col:int) -> None :
torch.ops._C.trans_w16_gemm(dst,src,row,col)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
......
......@@ -340,7 +340,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# prompt, and they have the same length.
if self.use_triton_flash_attn:
if self.use_flash_attn_auto:
if prefill_meta.max_prefill_seq_len >= 8000:
if prefill_meta.max_prefill_seq_len >= 4096:
out = self.attn_func_triton(
q=query,
k=key,
......
......@@ -5,6 +5,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd
import vllm.envs as envs
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
......@@ -122,26 +123,48 @@ class PagedAttention:
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
if envs.USE_VLLM_OPT_OP:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
......@@ -156,7 +179,8 @@ class PagedAttention:
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
if envs.USE_VLLM_OPT_OP:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
......@@ -179,6 +203,30 @@ class PagedAttention:
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
return output
@staticmethod
......
......@@ -10,6 +10,7 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
USE_VLLM_OPT_OP: bool = False
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
......@@ -138,6 +139,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_AUTO":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
("true", "1")),
# flag to control vllm to use optimized kernels
"USE_VLLM_OPT_OP":
lambda: (os.environ.get("USE_VLLM_OPT_OP", "True").lower() in
("true", "1")),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
......
......@@ -11,6 +11,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
import vllm.envs as envs
class SiluAndMul(CustomOp):
......@@ -34,7 +35,10 @@ class SiluAndMul(CustomOp):
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
if envs.USE_VLLM_OPT_OP:
ops.silu_and_mul(out, x)
else:
ops.silu_and_mul(out, x)
return out
......@@ -66,9 +70,15 @@ class GeluAndMul(CustomOp):
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "none":
ops.gelu_and_mul(out, x)
if envs.USE_VLLM_OPT_OP:
ops.gelu_and_mul_opt(out, x)
else:
ops.gelu_and_mul(out, x)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
if envs.USE_VLLM_OPT_OP:
ops.gelu_tanh_and_mul_opt(out, x)
else:
ops.gelu_tanh_and_mul(out, x)
return out
def extra_repr(self) -> str:
......
......@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp
import vllm.envs as envs
class RMSNorm(CustomOp):
......@@ -51,20 +52,36 @@ class RMSNorm(CustomOp):
from vllm import _custom_ops as ops
if residual is not None:
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
if envs.USE_VLLM_OPT_OP:
ops.fused_add_rms_norm_opt(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
else:
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
ops.rms_norm(
if envs.USE_VLLM_OPT_OP:
ops.rms_norm_opt(
out,
x,
self.weight.data,
self.variance_epsilon,
)
else:
ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
def extra_repr(self) -> str:
......
......@@ -22,7 +22,8 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM']
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM']
use_triton_fa_architectures = ['DeepseekV2ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1'
......@@ -35,6 +36,10 @@ def get_model_architecture(
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
if any(arch in architectures for arch in use_triton_fa_architectures):
os.environ['VLLM_USE_TRITON_FLASH_ATTN'] = '1'
os.environ['VLLM_USE_FLASH_ATTN_AUTO'] = '0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment