Commit 6fb681fc authored by zhanghj2
Browse files

lambda函数优化代码结构

parent 75f8262c
...@@ -136,10 +136,12 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -136,10 +136,12 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{}); Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{});
clear(acc_o); clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax; flash::Softmax<size<1>(acc_o)> softmax;
MainloopArgs args = get_cur_req_info(batch_idx);
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) { struct IsOrigBlock {};
struct IsExtraBlock {};
auto process_one_block = [&](int block_idx, auto is_extra_block_t) {
static constexpr bool IS_EXTRA_BLOCK = std::is_same_v<decltype(is_extra_block_t), IsExtraBlock>;
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}); Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{});
clear(acc_s); clear(acc_s);
int col_idx = lane_idx / 16; int col_idx = lane_idx / 16;
...@@ -395,6 +397,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -395,6 +397,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
cute::gemm(tiled_mma_o, tOrP(_, _, 2), tOrVt(_, _, 2), acc_o); cute::gemm(tiled_mma_o, tOrP(_, _, 2), tOrVt(_, _, 2), acc_o);
cute::gemm(tiled_mma_o, tOrP(_, _, 3), tOrVt(_, _, 3), acc_o); cute::gemm(tiled_mma_o, tOrP(_, _, 3), tOrVt(_, _, 3), acc_o);
} }
};
MainloopArgs args = get_cur_req_info(batch_idx);
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) {
process_one_block(block_idx, IsOrigBlock{});
} }
if (args.is_no_split) { if (args.is_no_split) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment