Commit 17bb1aaa authored by ltqin's avatar ltqin
Browse files

add alpha for dV and change alpha for dK dQ

parent 4cbab521
......@@ -256,6 +256,12 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
float rp_dropout = 1.0 / p_dropout;
float scale_rp_dropout = alpha * rp_dropout;
if(argc == 1)
{
// use default case
......@@ -479,7 +485,7 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
Scale{scale_rp_dropout}, //dQ *= scale_rp_dropout
QKVElementOp{},
YElementOp{});
......
......@@ -21,8 +21,10 @@ struct BlockwiseDropout
{
auto execute_dropout = [&](bool keep, DataType val) {
return keep ? val * p_dropout_rescale
: (using_sign_bit ? -val * p_dropout_rescale : float(0));
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
......
......@@ -742,28 +742,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
1, // DstScalarStrideInVector
true>;
using ABlockwiseCopy_dV = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::Relu,
Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(
I0), // ThreadSliceLengths
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>;
template <typename GridDesc_M0_O_M1>
using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
......@@ -1401,10 +1379,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
// dV: A matrix VGPR-to-LDS blockwise copy
auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy_dV{
auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
tensor_operation::element_wise::PassThrough{}};
// dV: B matrix global-to-LDS blockwise copy
auto vgrad_gemm_tile_ygrad_blockwise_copy =
......
......@@ -722,34 +722,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1,
false>;
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(
I0), // ThreadSliceLengths
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>;
using ABlockwiseCopy_dV = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::Relu,
ElementwiseOp,
Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(
I0), // ThreadSliceLengths
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
......@@ -1410,7 +1389,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
// dV: A matrix VGPR-to-LDS blockwise copy
auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy_dV{
auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds =
typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
......@@ -1438,11 +1418,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_multi_index(
I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);
auto vgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
auto vgrad_thread_copy_vgpr_to_global =
typename Gemm2::template CBlockwiseCopy<decltype(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
tensor_operation::element_wise::Scale>(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
tensor_operation::element_wise::PassThrough{});
tensor_operation::element_wise::Scale{rp_dropout});
// dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 =
......@@ -1453,7 +1435,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k);
// dK: A matrix VGPR-to-LDS blockwise copy
auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::PassThrough{}};
......@@ -1724,7 +1707,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop);
s_slash_p_thread_buf, ph);
block_sync_lds(); // wait for gemm1 LDS read
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment