Commit 91691124 authored by zhanghj2's avatar zhanghj2
Browse files

优化combine

parent c4412432
...@@ -40,7 +40,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) { ...@@ -40,7 +40,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
return; return;
} }
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); // FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
Tensor gLseAccum = make_tensor( Tensor gLseAccum = make_tensor(
make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M), make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
...@@ -127,6 +127,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) { ...@@ -127,6 +127,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
__syncthreads(); __syncthreads();
static_assert(HEAD_DIM_V % (64*4) == 0); static_assert(HEAD_DIM_V % (64*4) == 0);
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (64*4); constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (64*4);
static_assert(ELEMS_PER_THREAD == 2);
float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q; float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
float4 datas[ELEMS_PER_THREAD]; float4 datas[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
...@@ -165,24 +166,52 @@ flash_fwd_mla_combine_kernel(const CombineParams params) { ...@@ -165,24 +166,52 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
// printf(" %.3f \n", result[0].x); // printf(" %.3f \n", result[0].x);
// } // }
const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q; ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q + lane_idx * 8;
ElementT data_converted[8];
CUTLASS_PRAGMA_UNROLL using result_type = cutlass::Array<ElementT, 2>;
for (int i = 0; i < ELEMS_PER_THREAD; ++i) { for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
float4 data = result[i]; if constexpr(std::is_same_v<cutlass::bfloat16_t, ElementT>) {
ElementT data_converted[4]; #if defined(__gfx938__)
// auto res = __builtin_hcu_cvt_pk_bf16_f32(0, data.x, 0, data.y, 0); auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, result[i].x, 0, result[i].y, 0);
// data_converted[0].storage = res[0]; auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, result[i].z, 0, result[i].w, 0);
// data_converted[1].storage = res[1]; auto res0 = reinterpret_cast<result_type const &>(d0);
// res = __builtin_hcu_cvt_pk_bf16_f32(0, data.z, 0, data.w, 0); auto res1 = reinterpret_cast<result_type const &>(d1);
// data_converted[2].storage = res[0]; o_ptr[i * 4] = res0[0];
// data_converted[3].storage = res[1]; o_ptr[i * 4 + 1] = res0[1];
data_converted[0] = (ElementT)(data.x); o_ptr[i * 4 + 2] = res1[0];
data_converted[1] = (ElementT)(data.y); o_ptr[i * 4 + 3] = res1[1];
data_converted[2] = (ElementT)(data.z); #else
data_converted[3] = (ElementT)(data.w); // auto float32_to_bfloat16 = [&](float v) -> ElementT {
static_assert(sizeof(ElementT) == 2); // union {
*(uint64_t*)(o_ptr + lane_idx*8 + i*4) = *(uint64_t*)data_converted; // float fp32;
// uint32_t int32;
// } u = {v};
// ElementT res;
// res.storage = (u.int32 >> 16);
// return res;
// };
// float4 data = result[i];
// o_ptr[i * 4] = float32_to_bfloat16((data.x));
// o_ptr[i * 4 + 1] = float32_to_bfloat16((data.y));
// o_ptr[i * 4 + 2] = float32_to_bfloat16((data.z));
// o_ptr[i * 4 + 3] = float32_to_bfloat16((data.w));
data_converted[i * 4] = (ElementT)(data.x);
data_converted[i * 4 + 1] = (ElementT)(data.y);
data_converted[i * 4 + 2] = (ElementT)(data.z);
data_converted[i * 4 + 3] = (ElementT)(data.w);
#endif
} else {
auto d0 = __builtin_hcu_cvt_pkrtz(result[i].x, result[i].y);
auto d1 = __builtin_hcu_cvt_pkrtz(result[i].z, result[i].w);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
o_ptr[i * 4] = res0[0];
o_ptr[i * 4 + 1] = res0[1];
o_ptr[i * 4 + 2] = res1[0];
o_ptr[i * 4 + 3] = res1[1];
}
} }
} }
} }
......
...@@ -96,7 +96,7 @@ get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) { ...@@ -96,7 +96,7 @@ get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) {
} }
tile_scheduler_metadata_ptr[i] = cur_meta; tile_scheduler_metadata_ptr[i] = cur_meta;
} }
FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0); // FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
} }
__syncthreads(); __syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment