"vscode:/vscode.git/clone" did not exist on "756b30a5f30ee08b97243e1077419d8d74442b02"
Commit bd93e661 authored by zhuwenwen
Browse files

Update refactoring operation

parent 4405f82c
...@@ -157,10 +157,10 @@ set(VLLM_EXT_SRC ...@@ -157,10 +157,10 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu" "csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu" "csrc/opt/activation_kernels_opt.cu"
"csrc/opt/attention_kernels_opt.cu" "csrc/attention/attention_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu" "csrc/opt/layernorm_kernels_opt.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu" #"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu" # "csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <algorithm> #include <algorithm>
#include "../attention/attention_dtypes.h" #include "attention/attention_dtypes.h"
#include "../attention/attention_utils.cuh" #include "attention/attention_utils.cuh"
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
...@@ -70,7 +70,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, ...@@ -70,7 +70,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES = 1, bool odd_nheads = false, bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES = 1, bool odd_nheads = false,
int PARTITION_SIZE = 0> // Zero means no partitioning. int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel_opt(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -590,7 +590,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, ...@@ -590,7 +590,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int REUSE_KV_TIMES = 1, int REUSE_KV_TIMES = 1,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
bool odd_nheads = false> bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v1_kernel( __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
...@@ -608,7 +608,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel( ...@@ -608,7 +608,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>( KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens, v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens,
...@@ -625,7 +625,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, ...@@ -625,7 +625,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int REUSE_KV_TIMES, int REUSE_KV_TIMES,
int PARTITION_SIZE, int PARTITION_SIZE,
bool odd_nheads = false> bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v2_kernel( __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -647,7 +647,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel( ...@@ -647,7 +647,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>( KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
...@@ -659,7 +659,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel( ...@@ -659,7 +659,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE> int PARTITION_SIZE>
__global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel( __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -767,11 +767,11 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel( ...@@ -767,11 +767,11 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \ ((void*)vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \ BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \ KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \ shared_mem_size); \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \ NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \ , dim3(grid), dim3(block), shared_mem_size, stream, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
...@@ -918,7 +918,7 @@ void paged_attention_v1_opt( ...@@ -918,7 +918,7 @@ void paged_attention_v1_opt(
} }
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel_opt<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \ NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \ REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \ , dim3(grid), dim3(block), shared_mem_size, stream, \
...@@ -928,7 +928,7 @@ void paged_attention_v1_opt( ...@@ -928,7 +928,7 @@ void paged_attention_v1_opt(
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \ blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \ PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \ , dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
......
...@@ -83,11 +83,11 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input); ...@@ -83,11 +83,11 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);
void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void gelu_tanh_and_mul_opt(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
void gelu_new_opt(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
void gelu_fast_opt(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col); void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <cmath> #include <cmath>
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "../dispatch_utils.h"
namespace vllm { namespace vllm {
...@@ -25,7 +25,7 @@ __global__ void act_and_mul_kernel( ...@@ -25,7 +25,7 @@ __global__ void act_and_mul_kernel(
} }
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize1( __global__ void act_and_mul_kernel_opt1(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
...@@ -52,7 +52,7 @@ __global__ void act_and_mul_kernel_vectorize1( ...@@ -52,7 +52,7 @@ __global__ void act_and_mul_kernel_vectorize1(
} }
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize2( __global__ void act_and_mul_kernel_opt2(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
...@@ -120,23 +120,23 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { ...@@ -120,23 +120,23 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
input.scalar_type(), "act_and_mul_kernel", [&] { \ input.scalar_type(), "act_and_mul_kernel", [&] { \
if (0 == d % 8 && d <= 16384) { \ if (0 == d % 8 && d <= 16384) { \
if (d <= 512) { \ if (d <= 512) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 2> \ vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 2> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \ <<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \ input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \ } else if (d <= 1024) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \ vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \ <<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \ input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \ } else if (d <= 2048) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \ vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \ <<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \ input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \ } else if (d <= 4096) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \ vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \ <<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \ input.data_ptr<scalar_t>(), d); \
} else { \ } else { \
vllm::act_and_mul_kernel_vectorize2<scalar_t, KERNEL<scalar_t>, 8> \ vllm::act_and_mul_kernel_opt2<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \ <<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \ input.data_ptr<scalar_t>(), d); \
} \ } \
...@@ -165,64 +165,3 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d] ...@@ -165,64 +165,3 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
} }
namespace vllm {

// Generic element-wise activation kernel.
//
// Each thread block handles the row selected by blockIdx.x; the threads of
// the block walk that row's d elements in steps of blockDim.x, applying
// ACT_FN to every value read from `input` and storing the result at the same
// position in `out`.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., d]
    const int d) {
  // 64-bit offset of the first element of this block's row.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * d;
  const scalar_t* src_row = input + row_offset;
  scalar_t* dst_row = out + row_offset;

  int64_t col = threadIdx.x;
  while (col < d) {
    const scalar_t value = VLLM_LDG(&src_row[col]);
    dst_row[col] = ACT_FN(value);
    col += blockDim.x;
  }
}

}  // namespace vllm
// Launch element-wise activation kernel.
//
// Expands, in the caller's scope, to the statements that launch
// vllm::activation_kernel with KERNEL<scalar_t> as the activation function.
// The expansion refers to two names that must already be visible where the
// macro is used: `input` (tensor that is read) and `out` (tensor that is
// written).
//
// What the expansion does:
//   - d is the size of the last dimension of `input`, and num_tokens is
//     input.numel() / d, i.e. the number of rows of length d.
//   - The grid has num_tokens blocks; each block has min(d, 1024) threads.
//   - A device guard is created for the device of `input`, and the kernel is
//     launched on the current CUDA stream with no dynamic shared memory.
//   - VLLM_DISPATCH_FLOATING_TYPES dispatches on input.scalar_type(); the
//     lambda uses scalar_t for both data pointers and the kernel template.
//
// The locals d, num_tokens, grid, block, device_guard and stream are declared
// directly in the enclosing scope, so the macro can be expanded at most once
// per scope.
// NOTE(review): there is no guard for d == 0 (it is the divisor in
// num_tokens) — confirm callers never pass an empty last dimension.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
namespace vllm {
// Tanh-based GELU approximation, written as
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
// The constant 0.79788456f agrees numerically with sqrt(2 / pi).
// Intermediate results are converted back to T at every (T) cast, so the
// order of the casts below is part of the numerical behaviour.
template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
// x^3 is formed with T's own multiplication and only then converted to float.
const float x3 = (float)(x * x * x);
// The cubic term is converted to T before being added to x; the scaled sum is
// converted to T again before tanhf is applied.
const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T)0.5) * x * (((T)1.0) + t);
}
// Same approximation with the tanh argument factored as
//   (0.79788456 * x) * (1 + 0.044715 * x * x)
// which is algebraically equal to the argument used by gelu_new_kernel but
// applies the conversions to T at different points.
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
// Float copy of x used for the products with the float constants.
const float f = (float)x;
const T t =
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
}
} // namespace vllm
// Host entry point: applies vllm::gelu_new_kernel to every element of `input`
// and writes the results to `out`, by expanding LAUNCH_ACTIVATION_KERNEL.
// The parameter names `out` and `input` are referenced by that macro and must
// not be renamed.
// NOTE(review): assumes `out` has the same shape and dtype as `input`; nothing
// here checks it — confirm against callers.
void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
// Host entry point: applies vllm::gelu_fast_kernel to every element of `input`
// and writes the results to `out`, by expanding LAUNCH_ACTIVATION_KERNEL.
// The parameter names `out` and `input` are referenced by that macro and must
// not be renamed.
// NOTE(review): assumes `out` has the same shape and dtype as `input`; nothing
// here checks it — confirm against callers.
void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
...@@ -323,7 +323,7 @@ __inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) { ...@@ -323,7 +323,7 @@ __inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
} }
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512> template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps) __global__ void fused_add_rms_kernel_opt(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps)
{ {
constexpr int share_size=block_size/C10_WARP_SIZE; constexpr int share_size=block_size/C10_WARP_SIZE;
__shared__ T_ACC val_shared[share_size]; __shared__ T_ACC val_shared[share_size];
...@@ -363,7 +363,7 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca ...@@ -363,7 +363,7 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
} }
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512> template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t* gamma,int cols,T_ACC eps) __global__ void fused_rms_kernel_opt(scalar_t* input,scalar_t* output,scalar_t* gamma,int cols,T_ACC eps)
{ {
constexpr int share_size=block_size/C10_WARP_SIZE; constexpr int share_size=block_size/C10_WARP_SIZE;
__shared__ T_ACC val_shared[share_size]; __shared__ T_ACC val_shared[share_size];
...@@ -422,24 +422,24 @@ void rms_norm_opt(torch::Tensor& out, // [..., hidden_size] ...@@ -422,24 +422,24 @@ void rms_norm_opt(torch::Tensor& out, // [..., hidden_size]
scalar_t* out_data =out.data_ptr<scalar_t>(); scalar_t* out_data =out.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>(); scalar_t* weight_data=weight.data_ptr<scalar_t>();
if (hidden_size<=1024){ if (hidden_size<=1024){
fused_rms_kernel_eval<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps); fused_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
} }
else if(hidden_size<=2048){ else if(hidden_size<=2048){
fused_rms_kernel_eval<scalar_t,T_ACC,8,256><<<num_tokens, 256, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps); fused_rms_kernel_opt<scalar_t,T_ACC,8,256><<<num_tokens, 256, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
} }
else if(hidden_size<=4096){ else if(hidden_size<=4096){
if(num_tokens>1200){ if(num_tokens>1200){
fused_rms_kernel_eval<scalar_t,T_ACC,8,512><<<num_tokens, 512, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps); fused_rms_kernel_opt<scalar_t,T_ACC,8,512><<<num_tokens, 512, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
} }
else{ else{
fused_rms_kernel_eval<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps); fused_rms_kernel_opt<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
} }
} }
else if(hidden_size<=8192){ else if(hidden_size<=8192){
fused_rms_kernel_eval<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps); fused_rms_kernel_opt<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
} }
else{ else{
fused_rms_kernel_eval<scalar_t,T_ACC,16,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps); fused_rms_kernel_opt<scalar_t,T_ACC,16,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
} }
}); });
} }
...@@ -492,24 +492,24 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size] ...@@ -492,24 +492,24 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
scalar_t* other_data =residual.data_ptr<scalar_t>(); scalar_t* other_data =residual.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>(); scalar_t* weight_data=weight.data_ptr<scalar_t>();
if (hidden_size<=1024){ if (hidden_size<=1024){
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps); fused_add_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
} }
else if(hidden_size<=2048){ else if(hidden_size<=2048){
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,256><<<num_tokens, 256, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps); fused_add_rms_kernel_opt<scalar_t,T_ACC,8,256><<<num_tokens, 256, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
} }
else if(hidden_size<=4096){ else if(hidden_size<=4096){
if(num_tokens>1200){ if(num_tokens>1200){
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,512><<<num_tokens, 512, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps); fused_add_rms_kernel_opt<scalar_t,T_ACC,8,512><<<num_tokens, 512, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
} }
else{ else{
fused_add_rms_kernel_eval<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps); fused_add_rms_kernel_opt<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
} }
} }
else if(hidden_size<=8192){ else if(hidden_size<=8192){
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps); fused_add_rms_kernel_opt<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
} }
else{ else{
fused_add_rms_kernel_eval<scalar_t,T_ACC,16,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps); fused_add_rms_kernel_opt<scalar_t,T_ACC,16,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
} }
}); });
} }
......
...@@ -340,7 +340,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -340,7 +340,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# prompt, and they have the same length. # prompt, and they have the same length.
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
if self.use_flash_attn_auto: if self.use_flash_attn_auto:
if prefill_meta.max_prefill_seq_len >= 4096: if prefill_meta.max_prefill_seq_len >= 8000:
out = self.attn_func_triton( out = self.attn_func_triton(
q=query, q=query,
k=key, k=key,
......
...@@ -808,7 +808,10 @@ class ModelRunner: ...@@ -808,7 +808,10 @@ class ModelRunner:
import vllm.envs as envs import vllm.envs as envs
if envs.VLLM_USE_FLASH_ATTN_AUTO: if envs.VLLM_USE_FLASH_ATTN_AUTO:
for group_id in range(1): for group_id in range(1):
if max_num_batched_tokens >= 8000:
seq_len = 8000 seq_len = 8000
else:
seq_len = max_num_batched_tokens
if vlm_config is None: if vlm_config is None:
seq_data = SequenceData([0] * seq_len) seq_data = SequenceData([0] * seq_len)
dummy_multi_modal_data = None dummy_multi_modal_data = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment