Commit 17bb1aaa authored by ltqin
Browse files

add alpha for dV and change alpha for dK dQ

parent 4cbab521
...@@ -255,6 +255,12 @@ int run(int argc, char* argv[]) ...@@ -255,6 +255,12 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
float rp_dropout = 1.0 / p_dropout;
float scale_rp_dropout = alpha * rp_dropout;
if(argc == 1) if(argc == 1)
{ {
...@@ -479,7 +485,7 @@ int run(int argc, char* argv[]) ...@@ -479,7 +485,7 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
Scale{alpha}, Scale{scale_rp_dropout}, //dQ *= scale_rp_dropout
QKVElementOp{}, QKVElementOp{},
YElementOp{}); YElementOp{});
......
...@@ -21,8 +21,10 @@ struct BlockwiseDropout ...@@ -21,8 +21,10 @@ struct BlockwiseDropout
{ {
auto execute_dropout = [&](bool keep, DataType val) { auto execute_dropout = [&](bool keep, DataType val) {
return keep ? val * p_dropout_rescale if constexpr(using_sign_bit)
: (using_sign_bit ? -val * p_dropout_rescale : float(0)); return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
}; };
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
......
...@@ -742,28 +742,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -742,28 +742,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>; true>;
using ABlockwiseCopy_dV = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::Relu,
Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(
I0), // ThreadSliceLengths
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>;
template <typename GridDesc_M0_O_M1> template <typename GridDesc_M0_O_M1>
using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1< using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
...@@ -1401,10 +1379,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1401,10 +1379,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o); Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
// dV: A matrix VGPR-to-LDS blockwise copy // dV: A matrix VGPR-to-LDS blockwise copy
auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy_dV{ auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(), Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::Relu{}}; // relu(P-dropped) tensor_operation::element_wise::PassThrough{}};
// dV: B matrix global-to-LDS blockwise copy // dV: B matrix global-to-LDS blockwise copy
auto vgrad_gemm_tile_ygrad_blockwise_copy = auto vgrad_gemm_tile_ygrad_blockwise_copy =
......
...@@ -722,34 +722,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -722,34 +722,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1, typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1,
false>; false>;
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3< using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
DataType, DataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough, ElementwiseOp,
Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(
I0), // ThreadSliceLengths
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>;
using ABlockwiseCopy_dV = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::Relu,
Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At( Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(
I0), // ThreadSliceLengths I0), // ThreadSliceLengths
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1), Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
...@@ -1410,10 +1389,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1410,10 +1389,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o); Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
// dV: A matrix VGPR-to-LDS blockwise copy // dV: A matrix VGPR-to-LDS blockwise copy
auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy_dV{ auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds =
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(), Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
tensor_operation::element_wise::Relu{}}; // relu(P-dropped) Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
// dV: B matrix global-to-LDS blockwise copy // dV: B matrix global-to-LDS blockwise copy
auto vgrad_gemm_tile_ygrad_blockwise_copy = auto vgrad_gemm_tile_ygrad_blockwise_copy =
...@@ -1438,11 +1418,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1438,11 +1418,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_multi_index( make_multi_index(
I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0); I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);
auto vgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype( auto vgrad_thread_copy_vgpr_to_global =
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>( typename Gemm2::template CBlockwiseCopy<decltype(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4, tensor_operation::element_wise::Scale>(
tensor_operation::element_wise::PassThrough{}); vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
tensor_operation::element_wise::Scale{rp_dropout});
// dK: transform input and output tensor descriptors // dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 = const auto q_grid_desc_m0_k_m1 =
...@@ -1453,10 +1435,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1453,10 +1435,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k); Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k);
// dK: A matrix VGPR-to-LDS blockwise copy // dK: A matrix VGPR-to-LDS blockwise copy
auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{ auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(), Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
tensor_operation::element_wise::PassThrough{}}; Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::PassThrough{}};
// dK: B matrix global-to-LDS blockwise copy // dK: B matrix global-to-LDS blockwise copy
auto kgrad_gemm_tile_q_blockwise_copy = auto kgrad_gemm_tile_q_blockwise_copy =
...@@ -1724,7 +1707,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1724,7 +1707,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>( blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop); s_slash_p_thread_buf, ph);
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment