Commit 32b03f33 authored by danyao12
Browse files

dropout sync with pt2

parent 889c6bd5
...@@ -1265,7 +1265,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1265,7 +1265,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
{ {
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop); const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout); const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); const ushort p_dropout_in_16bits =
__builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() * const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout); rp_dropout);
...@@ -1520,7 +1521,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1520,7 +1521,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
n0, // NRepeat I1, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -1556,7 +1557,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1556,7 +1557,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
Sequence<I1, // MBlockId Sequence<I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
n0, // NRepeat I1, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -1958,19 +1959,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1958,19 +1959,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
if(p_z_grid) if(p_z_grid)
{ {
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), static_for<0, n0, 1>{}([&](auto i) {
decltype(z_tenor_buffer), blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
true>( decltype(z_tenor_buffer),
s_slash_p_thread_buf, ph, z_tenor_buffer); true,
decltype(n0),
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, decltype(i)>(
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), s_slash_p_thread_buf, ph, z_tenor_buffer);
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_copy_vgpr_to_global.Run(
z_grid_buf); z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
} }
else else
{ {
ignore = z_grid_buf;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>( blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph); s_slash_p_thread_buf, ph);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment