Commit 945ced44 authored by zhanghj2's avatar zhanghj2
Browse files

修复gfx936 bug

parent 60dfab33
......@@ -165,6 +165,12 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
// {
// printf(" %.3f \n", result[0].x);
// }
auto float2bf16 = [] (float s) -> uint16_t {
uint32_t x32 = reinterpret_cast<uint32_t const &>(s);
x32 += 0x8000u;
return uint16_t(x32 >> 16);
};
const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q + lane_idx * 8;
ElementT data_converted[8];
......@@ -196,10 +202,10 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
// o_ptr[i * 4 + 1] = float32_to_bfloat16((data.y));
// o_ptr[i * 4 + 2] = float32_to_bfloat16((data.z));
// o_ptr[i * 4 + 3] = float32_to_bfloat16((data.w));
data_converted[i * 4] = (ElementT)(data.x);
data_converted[i * 4 + 1] = (ElementT)(data.y);
data_converted[i * 4 + 2] = (ElementT)(data.z);
data_converted[i * 4 + 3] = (ElementT)(data.w);
o_ptr[i * 4].storage = float2bf16(data.x);
o_ptr[i * 4 + 1].storage = float2bf16(data.y);
o_ptr[i * 4 + 2].storage = float2bf16(data.z);
o_ptr[i * 4 + 3].storage = float2bf16(data.w);
#endif
} else {
auto d0 = __builtin_hcu_cvt_pkrtz(result[i].x, result[i].y);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment