Commit c860a754 authored by danyao12
Browse files

optimized code

parent 0286b9bf
...@@ -1945,14 +1945,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1945,14 +1945,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
{ s_element_op(s_slash_p_thread_buf(i),
s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity(); masked_flag ? -ck::NumericLimits<float>::Infinity()
} : s_slash_p_thread_buf[i]);
else
{
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
}
}); });
} }
else else
...@@ -2015,17 +2011,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2015,17 +2011,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
constexpr auto m = constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0]; pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P has same thread buf layout // dS and P has same thread buf layout
if(s_slash_p_thread_buf[i] >= 0) bool undropped_flag = s_slash_p_thread_buf[i] >= 0;
{
sgrad_thread_buf(i) = sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] * s_slash_p_thread_buf[i] *
(pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]); (undropped_flag ? (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}])
} : y_dot_ygrad_thread_buf[Number<m>{}]);
else
{
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
}
}); });
// gemm dQ // gemm dQ
...@@ -2086,6 +2076,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2086,6 +2076,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
p_slice_idx[I3], p_slice_idx[I3],
p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3)); p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run( vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
...@@ -2096,8 +2087,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2096,8 +2087,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
gemm2_a_block_buf); gemm2_a_block_buf);
} }
// block_sync_lds(); // sync before write
vgrad_gemm_tile_ygrad_blockwise_copy.Run(Gemm2::b_block_desc_o0_o1_o2_m0_m1_m2_m3, vgrad_gemm_tile_ygrad_blockwise_copy.Run(Gemm2::b_block_desc_o0_o1_o2_m0_m1_m2_m3,
ygrad_block_buf, ygrad_block_buf,
Gemm2::b_thread_desc_o0_o1_o2_m0_m1_m2_m3, Gemm2::b_thread_desc_o0_o1_o2_m0_m1_m2_m3,
...@@ -2135,6 +2124,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2135,6 +2124,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
sgrad_slice_idx[I3] + sgrad_slice_idx[I3] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3)); Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run( kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
......
...@@ -1840,14 +1840,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1840,14 +1840,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
{ s_element_op(s_slash_p_thread_buf(i),
s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity(); masked_flag ? -ck::NumericLimits<float>::Infinity()
} : s_slash_p_thread_buf[i]);
else
{
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
}
}); });
} }
else else
...@@ -1924,6 +1920,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1924,6 +1920,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
p_slice_idx[I3], p_slice_idx[I3],
p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3)); p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run( vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
...@@ -1939,7 +1936,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1939,7 +1936,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow( vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step); ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1, vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf); gemm2_b_block_buf);
...@@ -1987,17 +1983,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1987,17 +1983,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
constexpr auto m = constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0]; pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P has same thread buf layout // dS and P has same thread buf layout
if(s_slash_p_thread_buf[i] >= 0) bool undropped_flag = s_slash_p_thread_buf[i] >= 0;
{
sgrad_thread_buf(i) = sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] * s_slash_p_thread_buf[i] *
(pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]); (undropped_flag ? (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}])
} : y_dot_ygrad_thread_buf[Number<m>{}]);
else
{
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
}
}); });
// gemm dQ // gemm dQ
...@@ -2082,6 +2072,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2082,6 +2072,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
sgrad_slice_idx[I3] + sgrad_slice_idx[I3] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3)); Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run( kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
...@@ -2098,7 +2089,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2098,7 +2089,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1, kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
Gemm2::b_block_slice_copy_step); Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1, kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf); gemm2_b_block_buf);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment