Commit 91691124 authored by zhanghj2's avatar zhanghj2
Browse files

优化combine

parent c4412432
......@@ -40,7 +40,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
return;
}
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
// FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
Tensor gLseAccum = make_tensor(
make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
......@@ -127,6 +127,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
__syncthreads();
static_assert(HEAD_DIM_V % (64*4) == 0);
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (64*4);
static_assert(ELEMS_PER_THREAD == 2);
float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
float4 datas[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
......@@ -165,24 +166,52 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
// printf(" %.3f \n", result[0].x);
// }
const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;
CUTLASS_PRAGMA_UNROLL
ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q + lane_idx * 8;
ElementT data_converted[8];
using result_type = cutlass::Array<ElementT, 2>;
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
float4 data = result[i];
ElementT data_converted[4];
// auto res = __builtin_hcu_cvt_pk_bf16_f32(0, data.x, 0, data.y, 0);
// data_converted[0].storage = res[0];
// data_converted[1].storage = res[1];
// res = __builtin_hcu_cvt_pk_bf16_f32(0, data.z, 0, data.w, 0);
// data_converted[2].storage = res[0];
// data_converted[3].storage = res[1];
data_converted[0] = (ElementT)(data.x);
data_converted[1] = (ElementT)(data.y);
data_converted[2] = (ElementT)(data.z);
data_converted[3] = (ElementT)(data.w);
static_assert(sizeof(ElementT) == 2);
*(uint64_t*)(o_ptr + lane_idx*8 + i*4) = *(uint64_t*)data_converted;
if constexpr(std::is_same_v<cutlass::bfloat16_t, ElementT>) {
#if defined(__gfx938__)
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, result[i].x, 0, result[i].y, 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, result[i].z, 0, result[i].w, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
o_ptr[i * 4] = res0[0];
o_ptr[i * 4 + 1] = res0[1];
o_ptr[i * 4 + 2] = res1[0];
o_ptr[i * 4 + 3] = res1[1];
#else
// auto float32_to_bfloat16 = [&](float v) -> ElementT {
// union {
// float fp32;
// uint32_t int32;
// } u = {v};
// ElementT res;
// res.storage = (u.int32 >> 16);
// return res;
// };
// float4 data = result[i];
// o_ptr[i * 4] = float32_to_bfloat16((data.x));
// o_ptr[i * 4 + 1] = float32_to_bfloat16((data.y));
// o_ptr[i * 4 + 2] = float32_to_bfloat16((data.z));
// o_ptr[i * 4 + 3] = float32_to_bfloat16((data.w));
data_converted[i * 4] = (ElementT)(data.x);
data_converted[i * 4 + 1] = (ElementT)(data.y);
data_converted[i * 4 + 2] = (ElementT)(data.z);
data_converted[i * 4 + 3] = (ElementT)(data.w);
#endif
} else {
auto d0 = __builtin_hcu_cvt_pkrtz(result[i].x, result[i].y);
auto d1 = __builtin_hcu_cvt_pkrtz(result[i].z, result[i].w);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
o_ptr[i * 4] = res0[0];
o_ptr[i * 4 + 1] = res0[1];
o_ptr[i * 4 + 2] = res1[0];
o_ptr[i * 4 + 3] = res1[1];
}
}
}
}
......
......@@ -96,7 +96,7 @@ get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) {
}
tile_scheduler_metadata_ptr[i] = cur_meta;
}
FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
// FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
}
__syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment