Commit b4cf96af authored by zhuwenwen
Browse files

Merge branch 'vllm-0.6.2-zhagnshao' into 'v0.6.2-dev'

Add paged-attention (PA) TC kernel optimizations for BW (gfx936). (原文: 增加bw pa tc优化)

See merge request dcutoolkit/deeplearing/vllm!54
parents ec0136e7 8af5263f
......@@ -6,26 +6,19 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define WARP_SIZE 64
#include "static_switch_tc.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
inline std::string get_device_name()
std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
......@@ -43,6 +36,9 @@ inline std::string get_device_name()
const std::string raw_name(props.gcnArchName);
return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
}
static const std::string device_name=get_device_name();
static inline int get_env_(const char *env_var) {
if (char *value = std::getenv(env_var)) {
return atoi(value);
......@@ -170,8 +166,8 @@ __device__ void paged_attention_kernel_TC(
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.y;
const int partition_idx = blockIdx.x;
const int max_num_partitions = gridDim.x;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
......@@ -203,14 +199,10 @@ __device__ void paged_attention_kernel_TC(
const int num_queries_per_kv = num_heads / num_kv_heads;
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
const int odd_tg_round = (((blockIdx.z * gridDim.y * gridDim.x) + blockIdx.y * gridDim.x) / 128) % 2;
const int mid_x = gridDim.x / 2;
const int blockIdx_shift = (odd_tg_round | (gridDim.x & 1)) ? blockIdx.x : (blockIdx.x < mid_x ? (blockIdx.x + mid_x) : (blockIdx.x - mid_x));
const int head_idx = (blockIdx_shift / num_blocks_per_kv) * num_queries_per_kv + (blockIdx_shift % num_blocks_per_kv) * REUSE_KV_TIMES;
//const int head_idx=(blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int head_idx=(blockIdx.y / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.y % num_blocks_per_kv) * REUSE_KV_TIMES;
int q_boundary=REUSE_KV_TIMES;
if(num_heads < REUSE_KV_TIMES*gridDim.x && (num_blocks_per_kv-1)*REUSE_KV_TIMES == head_idx%num_queries_per_kv)
if(num_heads < REUSE_KV_TIMES*gridDim.y && (num_blocks_per_kv-1)*REUSE_KV_TIMES == head_idx%num_queries_per_kv)
q_boundary=num_queries_per_kv-(num_blocks_per_kv-1)*REUSE_KV_TIMES;
const int kv_head_idx = head_idx / num_queries_per_kv;
constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
......@@ -233,7 +225,7 @@ __device__ void paged_attention_kernel_TC(
q_vec.data[1]={0,0,0,0};
__shared__ half4x2 q_vecs[REUSE_KV_TIMES][16];
//if(thread_idx==0)printf("blockIdx.x==%d,q_boundary=%d,head_idx=%d,kv_head_idx=%d\n",blockIdx.x,q_boundary,head_idx,kv_head_idx);
//if(thread_idx==0)printf("blockIdx.y==%d,q_boundary=%d,head_idx=%d,kv_head_idx=%d\n",blockIdx.y,q_boundary,head_idx,kv_head_idx);
for(int i=0;i<q_boundary;i++){
if(thread_idx<16){
q_vecs[i][thread_idx]=*reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
......@@ -303,7 +295,7 @@ __device__ void paged_attention_kernel_TC(
}
}
}
// if(blockIdx.x==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]);
// if(blockIdx.y==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]);
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
......@@ -353,7 +345,6 @@ __device__ void paged_attention_kernel_TC(
*exp_sums_ptr = exp_sum;
}
}
constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2
if constexpr(REUSE_KV_TIMES<=2){
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
......@@ -441,6 +432,7 @@ __device__ void paged_attention_kernel_TC(
}
}
}
#if defined __gfx928__
else{
constexpr int GROUPS=reuse_group*4;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
......@@ -533,6 +525,99 @@ __device__ void paged_attention_kernel_TC(
}
}
}
#else
else{
constexpr int GROUPS=reuse_group*4;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float4_t accs[4][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++)
{
accs[k][i] = {0.f,0.f,0.f,0.f};
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < 4; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
builtin_amdgcn_mmac<is_half,use_vmac>(v_vec,logits_vec,accs[k][i]);
}
}
}
if constexpr (NUM_THREADS>64){
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(reuse_group * sizeof(float)) )) float;
// Perform reduction across warps.
for(int m=0; m<4; m++) {
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
out_smem[((warp_idx - mid) * 64+lane)*NUM_ROWS_PER_THREAD+k]=*(floatV_t*)(&(accs[m][k]));
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
floatV_t tmp=out_smem[thread_idx*NUM_ROWS_PER_THREAD+k];
#pragma unroll
for (int i = 0; i < reuse_group; i++) {
accs[m][k][i] += tmp[i];
}
}
}
__syncthreads();
}
}
}
if (warp_idx == 0) {
for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows;
if(reusekvid<q_boundary){
scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
(head_idx+reusekvid) * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
for(int k=0;k<4;k++){
const int row_idx = rowid+16*k + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[k][i][g]);
}
}
}
}
}
}
#endif
}
......@@ -736,6 +821,35 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize,int seq,int qheads,int kvheads){
//mha
reusekv=1;
num_thread=256;
if(device_name=="gfx936"){//bw
if(qheads==kvheads){
if(seq<16){num_thread=64;return;}
if(batchsize>=32&&seq>=1000)return;
if(batchsize*qheads>=512)num_thread=64;
return;
}
if(seq<=16){
num_thread=64;
if(qheads*batchsize>1000)reusekv=4;
return;
}
if(seq<=64){
if(qheads*batchsize>1000)reusekv=4;
return;
}
if(seq<=200){
if(qheads*batchsize>400)reusekv=4;
return;
}
if(seq<=500){
if(qheads*batchsize>200)reusekv=4;
return;
}
if(((qheads-1)/16+1)*batchsize>=64&&qheads/kvheads>4&&seq<7800)reusekv=8;
else if(qheads*batchsize>100)reusekv=4;
return;
}
if(qheads==kvheads){
//llama 7B; not yet known for other models
if(seq<=16||batchsize>=32)num_thread=64;
......@@ -844,7 +958,7 @@ void paged_attention_v1_launcher_opt_tc(
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
if (NUM_WARPS==64)outputs_size=0;
int shared_mem_size = ::max(logits_size, outputs_size);
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs);
dim3 grid(1,(num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads,num_seqs);
dim3 block(NUM_THREADS);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
......@@ -927,7 +1041,7 @@ void paged_attention_v1_opt_tc(
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
......@@ -958,8 +1072,32 @@ void paged_attention_v1_opt_tc(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize,int max_num_partitions,int qheads,int kvheads){
void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize,int max_num_partitions,int qheads,int kvheads,int num_blocks){
reusekv=1;
num_thread=256;
if(device_name=="gfx936"){//bw
if(max_num_partitions==16&&num_blocks==1024){//ali test
if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;}
if(batchsize==1&&qheads==32&&kvheads==32){num_thread=64;return;}
if(batchsize==1){
if(qheads==52){reusekv=8;return;}
if(qheads==13){reusekv=2;return;}
reusekv=4;return;
}
if(batchsize==64){
if(qheads==13||qheads==32){num_thread=128;reusekv=8;return;}
if(qheads==52){reusekv=16;return;}
reusekv=8;return;
}
}
if(qheads==kvheads)return;
int bp=max_num_partitions*batchsize;
if(qheads/kvheads>4){
if(qheads==16&&bp>96||qheads<16&&bp>=192||qheads>16&&bp>24){reusekv=8;return;}
}
if(qheads/4*bp>=32)reusekv=4;
return;
}
int blocks=batchsize*qheads*max_num_partitions;
if(qheads==kvheads){
if(blocks<=80||blocks>8000){num_thread=256;}
......@@ -1009,6 +1147,7 @@ void paged_attention_v2_launcher_opt_tc(
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int num_blocks=key_cache.size(0);
// printf("paged_attention_v2\n");
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
......@@ -1036,11 +1175,10 @@ void paged_attention_v2_launcher_opt_tc(
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){
//if(head_size==128&&get_device_name()=="gfx928"){
constexpr int HEAD_SIZE=128;
constexpr static int use_vmac = false;
int reusekv, num_thread;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads);
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads,num_blocks);
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
REUSEKV_SWITCH(reusekv,[&] {
......@@ -1049,8 +1187,8 @@ void paged_attention_v2_launcher_opt_tc(
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid;
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.y = max_num_partitions;
grid.y = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.x = max_num_partitions;
grid.z = num_seqs;
dim3 block(NUM_THREADS);
int shared_mem_size = ::max(logits_size, outputs_size);
......@@ -1060,7 +1198,6 @@ void paged_attention_v2_launcher_opt_tc(
});
});
}
//}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
......@@ -1145,7 +1282,7 @@ void paged_attention_v2_opt_tc(
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment