Commit bdf0140b authored by zhanghj2's avatar zhanghj2
Browse files

使用buffer load lds读取q, 优化了vgpr溢出

parent 515dbd44
...@@ -30,40 +30,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -30,40 +30,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
using index_t = int64_t; using index_t = int64_t;
const int tidx = threadIdx.x; const int tidx = threadIdx.x;
const int lane_idx = tidx % 64; const int lane_idx = tidx % 64;
const int warp_idx = tidx / 64; const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64);
const int head_block_idx = NUM_M_BLOCKS == 1 ? 0 : blockIdx.x; const int head_block_idx = NUM_M_BLOCKS == 1 ? 0 : blockIdx.x;
const int s_q_idx = blockIdx.y; const int s_q_idx = blockIdx.y;
extern __shared__ char shared_memory[]; extern __shared__ char shared_memory[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(shared_memory); SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(shared_memory);
// Per-request main-loop bounds, as filled in by get_cur_req_info.
struct MainloopArgs {
    // Half-open range [start_block_idx, end_block_idx) of top-k blocks to process.
    int start_block_idx;
    int end_block_idx;
    // True when the request is handled without being split (see sched_meta flags).
    bool is_no_split;
    // The following fields are only valid for MODEL1
    int topk_length;
    int extra_topk_length;
    int num_orig_kv_blocks;
};
// Builds the MainloopArgs for request `batch_idx`.
//
// The block range defaults to [0, total_topk_padded / TOPK_BLOCK_SIZE) and is
// overridden by sched_meta.begin_block_idx / sched_meta.end_block_idx when
// `batch_idx` equals sched_meta.begin_req_idx / sched_meta.end_req_idx.
// For the V32 model only the range and is_no_split are written; the remaining
// fields of the returned struct are left unset.
auto get_cur_req_info = [&](int batch_idx) -> MainloopArgs {
    MainloopArgs args;
    int total_topk_padded;
    if constexpr (MODEL_TYPE != ModelType::V32) {
        // Per-request lengths come from the optional device arrays, falling back
        // to the scalar defaults in `params` when an array pointer is null.
        const int topk_length =
            params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;
        const int extra_topk_length =
            params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
        // ku::ceil presumably rounds up to a multiple of the block size — TODO confirm.
        // The max() keeps at least one full block for the original top-k part.
        const int orig_topk_padded =
            max(ku::ceil(topk_length, static_cast<int>(TOPK_BLOCK_SIZE)), static_cast<int>(TOPK_BLOCK_SIZE));
        total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, static_cast<int>(TOPK_BLOCK_SIZE));
        args.topk_length = topk_length;
        args.extra_topk_length = extra_topk_length;
        args.num_orig_kv_blocks = orig_topk_padded / TOPK_BLOCK_SIZE;
    } else {
        total_topk_padded = params.topk;
    }

    const bool is_begin_req = batch_idx == sched_meta.begin_req_idx;
    const bool is_end_req = batch_idx == sched_meta.end_req_idx;

    args.start_block_idx = is_begin_req ? sched_meta.begin_block_idx : 0;
    args.end_block_idx = is_end_req ? sched_meta.end_block_idx : total_topk_padded / TOPK_BLOCK_SIZE;

    // The begin-request flag takes precedence when a request is both first and last.
    if (is_begin_req) {
        args.is_no_split = !sched_meta.is_first_req_splitted;
    } else if (is_end_req) {
        args.is_no_split = !sched_meta.is_last_req_splitted;
    } else {
        args.is_no_split = true;
    }
    return args;
};
const index_t row_offset_q = batch_idx * params.stride_q_b + head_block_idx * BLOCK_M * params.stride_q_h_q + s_q_idx * params.stride_q_s_q; const index_t row_offset_q = batch_idx * params.stride_q_b + head_block_idx * BLOCK_M * params.stride_q_h_q + s_q_idx * params.stride_q_s_q;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q) + row_offset_q), Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q) + row_offset_q),
...@@ -90,7 +61,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -90,7 +61,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
auto thr_mma_16x16x32 = tiled_mma_16x16x32.get_thread_slice(tidx); auto thr_mma_16x16x32 = tiled_mma_16x16x32.get_thread_slice(tidx);
TiledMMA tiled_mma_o = TiledMma_O{}; TiledMMA tiled_mma_o = TiledMma_O{};
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
#if 0
// load Q // load Q
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma); auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
...@@ -101,6 +72,196 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -101,6 +72,196 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ))); Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.h_q - head_block_idx * BLOCK_M); flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.h_q - head_block_idx * BLOCK_M);
__syncthreads(); __syncthreads();
#else
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
// 需要的最大空间为 16 * 576 * 2
Element* s_q = reinterpret_cast<Element *>(shared_memory);
auto lds_direct_copy_q = [&](const int k_idx, const int offset_k) {
// static_assert(offset_k == 0 || offset_k == 1);
// static_assert(k_idx < 3);
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gQ.data().get());
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = 64 * 8 * 2;
constexpr int bytes_per_block = bytes_per_warp * 4;
const int row_idx = lane_idx % 16;
const int col_idx = lane_idx / 16;
const int row_offset = row_idx;
if constexpr (MODEL_TYPE == ModelType::V32)
{
int col_offset;
if (k_idx == 2)
{
col_offset = k_idx * 256 + warp_idx * 8 + col_idx * 16;
}
else
{
col_offset = k_idx * 256 + warp_idx * 64 + col_idx * 16 + offset_k * 8;
}
int offset_v = (row_offset * params.stride_q_h_q + col_offset) * 2;
if (head_block_idx * BLOCK_M + row_idx >= params.h_q) {
offset_v = -1;
}
if (k_idx == 2 && warp_idx >= 2)
{
offset_v = -1;
}
const int offset_s = 0;
int ldsAddrPerWave = reinterpret_cast<size_t>(s_q) + warp_idx * bytes_per_warp + k_idx * bytes_per_block
+ offset_k * 3 * bytes_per_block;
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
else
{
const int col_offset = k_idx * 256 + warp_idx * 64 + col_idx * 16 + offset_k * 8;
int offset_v = (row_offset * params.stride_q_h_q + col_offset) * 2;
if (head_block_idx * BLOCK_M + row_idx >= params.h_q) {
offset_v = -1;
}
const int offset_s = 0;
int ldsAddrPerWave = reinterpret_cast<size_t>(s_q) + warp_idx * bytes_per_warp + k_idx * bytes_per_block
+ offset_k * 2 * bytes_per_block;
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
};
// Stage Q through LDS and unpack it into the tSrQ register fragment.
//
// All loads are issued up front; each `s_waitcnt vmcnt(N)` then waits until at
// most N of this wave's loads are still outstanding, and the following s_barrier
// makes the corresponding slot from every wave visible before it is read.
// Each read step consumes 8 consecutive elements per lane (lane_idx * 8) from one
// 16 * 32-element slice per k.
//
// Every wait/barrier asm carries a "memory" clobber so the compiler cannot move
// the ordinary LDS reads above the wait that guards them, nor sink them below the
// final barrier.
if constexpr (MODEL_TYPE == ModelType::V32)
{
    // __builtin_amdgcn_sched_barrier(0);
    lds_direct_copy_q(0, 0);
    lds_direct_copy_q(1, 0);
    lds_direct_copy_q(0, 1);
    lds_direct_copy_q(1, 1);
    lds_direct_copy_q(2, 0);
    Element* s_q_read_ptr = s_q + lane_idx * 8;
    // Load (0, 0) done: first 8 elements of each fragment for k = 0..3.
    asm volatile("s_waitcnt vmcnt(4) \n s_barrier" ::: "memory");
    for (int k = 0; k < 4; k++)
    {
        for (int i = 0; i < 8; i++)
        {
            tSrQ(i, 0, k) = s_q_read_ptr[i];
        }
        s_q_read_ptr += 16 * 32;
    }
    // Load (1, 0) done: first 8 elements for k = 4..7 (pointer continues into the next slot).
    asm volatile("s_waitcnt vmcnt(3) \n s_barrier" ::: "memory");
    for (int k = 4; k < 8; k++)
    {
        for (int i = 0; i < 8; i++)
        {
            tSrQ(i, 0, k) = s_q_read_ptr[i];
        }
        s_q_read_ptr += 16 * 32;
    }
    // Load (0, 1) done: it was written to the fourth slot (index 3), so rebase the pointer there.
    asm volatile("s_waitcnt vmcnt(2) \n s_barrier" ::: "memory");
    s_q_read_ptr = s_q + lane_idx * 8 + 3 * 4 * 16 * 4 * 8;
    for (int k = 0; k < 4; k++)
    {
        for (int i = 8; i < 16; i++)
        {
            tSrQ(i, 0, k) = s_q_read_ptr[i - 8];
        }
        s_q_read_ptr += 16 * 32;
    }
    // Load (1, 1) done: second 8 elements for k = 4..7.
    asm volatile("s_waitcnt vmcnt(1) \n s_barrier" ::: "memory");
    for (int k = 4; k < 8; k++)
    {
        for (int i = 8; i < 16; i++)
        {
            tSrQ(i, 0, k) = s_q_read_ptr[i - 8];
        }
        s_q_read_ptr += 16 * 32;
    }
    // Tail load (2, 0) done: it lives in the third slot (index 2). Wave 0's slice
    // supplies elements 0..7 of k = 8 and wave 1's slice supplies elements 8..15.
    asm volatile("s_waitcnt vmcnt(0) \n s_barrier" ::: "memory");
    s_q_read_ptr = s_q + lane_idx * 8 + 2 * 4 * 16 * 4 * 8;
    for (int k = 8; k < 9; k++)
    {
        for (int i = 0; i < 8; i++)
        {
            tSrQ(i, 0, k) = s_q_read_ptr[i];
        }
        s_q_read_ptr += 16 * 32;
    }
    for (int k = 8; k < 9; k++)
    {
        for (int i = 8; i < 16; i++)
        {
            tSrQ(i, 0, k) = s_q_read_ptr[i-8];
        }
        s_q_read_ptr += 16 * 32;
    }
    // __syncthreads();
    // Drain the LDS reads and synchronise before the staging area is reused.
    asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier" ::: "memory");
    // __builtin_amdgcn_sched_barrier(0);
}
else
{
    // __builtin_amdgcn_sched_barrier(0);
    lds_direct_copy_q(0, 0);
    lds_direct_copy_q(1, 0);
    lds_direct_copy_q(0, 1);
    lds_direct_copy_q(1, 1);
    // The four slots are contiguous on this path, so the read pointer simply
    // walks forward through them without rebasing.
    Element* s_q_read_ptr = s_q + lane_idx * 8;
    // Load (0, 0) done: first 8 elements for k = 0..3.
    asm volatile("s_waitcnt vmcnt(3) \n s_barrier" ::: "memory");
    for (int k = 0; k < 4; k++)
    {
        for (int i = 0; i < 8; i++)
        {
            tSrQ(i, 0, k) = s_q_read_ptr[i];
        }
        s_q_read_ptr += 16 * 32;
    }
    // Load (1, 0) done: first 8 elements for k = 4..7.
    asm volatile("s_waitcnt vmcnt(2) \n s_barrier" ::: "memory");
    for (int k = 4; k < 8; k++)
    {
        for (int i = 0; i < 8; i++)
        {
            tSrQ(i, 0, k) = s_q_read_ptr[i];
        }
        s_q_read_ptr += 16 * 32;
    }
    // Load (0, 1) done: second 8 elements for k = 0..3.
    asm volatile("s_waitcnt vmcnt(1) \n s_barrier" ::: "memory");
    for (int k = 0; k < 4; k++)
    {
        for (int i = 8; i < 16; i++)
        {
            tSrQ(i, 0, k) = s_q_read_ptr[i - 8];
        }
        s_q_read_ptr += 16 * 32;
    }
    // Load (1, 1) done: second 8 elements for k = 4..7.
    asm volatile("s_waitcnt vmcnt(0) \n s_barrier" ::: "memory");
    for (int k = 4; k < 8; k++)
    {
        for (int i = 8; i < 16; i++)
        {
            tSrQ(i, 0, k) = s_q_read_ptr[i - 8];
        }
        s_q_read_ptr += 16 * 32;
    }
    // Drain the LDS reads and synchronise before the staging area is reused.
    asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier" ::: "memory");
    // __builtin_amdgcn_sched_barrier(0);
}
#endif
// zhj debug // zhj debug
// if (head_block_idx == 0) // if (head_block_idx == 0)
// { // {
...@@ -133,7 +294,35 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -133,7 +294,35 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
uint32x4_t data_128; uint32x4_t data_128;
uint16_t data_array[8]; uint16_t data_array[8];
}; };
// Per-request main-loop bounds, as filled in by get_cur_req_info.
struct MainloopArgs {
    // Half-open range [start_block_idx, end_block_idx) of top-k blocks to process.
    int start_block_idx;
    int end_block_idx;
    // True when the request is handled without being split (see sched_meta flags).
    bool is_no_split;
    // The following fields are only valid for MODEL1
    int topk_length;
    int extra_topk_length;
    int num_orig_kv_blocks;
};
// Builds the MainloopArgs for request `batch_idx`.
//
// The block range defaults to [0, total_topk_padded / TOPK_BLOCK_SIZE) and is
// overridden by sched_meta.begin_block_idx / sched_meta.end_block_idx when
// `batch_idx` equals sched_meta.begin_req_idx / sched_meta.end_req_idx.
// For the V32 model only the range and is_no_split are written; the remaining
// fields of the returned struct are left unset.
auto get_cur_req_info = [&](int batch_idx) -> MainloopArgs {
    MainloopArgs args;
    int total_topk_padded;
    if constexpr (MODEL_TYPE != ModelType::V32) {
        // Per-request lengths come from the optional device arrays, falling back
        // to the scalar defaults in `params` when an array pointer is null.
        const int topk_length =
            params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;
        const int extra_topk_length =
            params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
        // ku::ceil presumably rounds up to a multiple of the block size — TODO confirm.
        // The max() keeps at least one full block for the original top-k part.
        const int orig_topk_padded =
            max(ku::ceil(topk_length, static_cast<int>(TOPK_BLOCK_SIZE)), static_cast<int>(TOPK_BLOCK_SIZE));
        total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, static_cast<int>(TOPK_BLOCK_SIZE));
        args.topk_length = topk_length;
        args.extra_topk_length = extra_topk_length;
        args.num_orig_kv_blocks = orig_topk_padded / TOPK_BLOCK_SIZE;
    } else {
        total_topk_padded = params.topk;
    }

    const bool is_begin_req = batch_idx == sched_meta.begin_req_idx;
    const bool is_end_req = batch_idx == sched_meta.end_req_idx;

    args.start_block_idx = is_begin_req ? sched_meta.begin_block_idx : 0;
    args.end_block_idx = is_end_req ? sched_meta.end_block_idx : total_topk_padded / TOPK_BLOCK_SIZE;

    // The begin-request flag takes precedence when a request is both first and last.
    if (is_begin_req) {
        args.is_no_split = !sched_meta.is_first_req_splitted;
    } else if (is_end_req) {
        args.is_no_split = !sched_meta.is_last_req_splitted;
    } else {
        args.is_no_split = true;
    }
    return args;
};
Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{}); Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{});
clear(acc_o); clear(acc_o);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment