Unverified Commit d2eed8e6 authored by guangzlu's avatar guangzlu Committed by GitHub
Browse files

Merge branch 'attn-bwd-dropout' into fwd-drop-verify2

parents 043c8ff3 e9e6081a
...@@ -10,7 +10,8 @@ int run(int argc, char* argv[]) ...@@ -10,7 +10,8 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
float p_drop = 0.1;
float p_drop = 0.2;
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
...@@ -251,6 +252,7 @@ int run(int argc, char* argv[]) ...@@ -251,6 +252,7 @@ int run(int argc, char* argv[])
{seed, offset}); // dropout random seed and offset, offset should be {seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread // at least the number of elements on a thread
// specify workspace for problem_desc // specify workspace for problem_desc
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
......
...@@ -623,6 +623,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -623,6 +623,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; // z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
z_grid_desc_m_n); z_grid_desc_m_n);
......
...@@ -1019,6 +1019,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1019,6 +1019,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment