Unverified commit 21ef37b4, authored by Dan Yao, committed by GitHub
Browse files

Merge pull request #889 from ROCmSoftwarePlatform/mha-train-develop-bwdopt-bias

Mha train develop bwdopt bias
parents 1f04cd2b db579ac9
...@@ -142,8 +142,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -142,8 +142,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
y_grid_desc_mblock_mperblock_nblock_nperblock, y_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDesc_M& d_grid_desc_m, const DGridDesc_M& d_grid_desc_m,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map,
const float p_drop)
{ {
const FloatD p_dropout = type_convert<FloatD>(1.0f - p_drop);
const tensor_operation::element_wise::Scale scale_p_dropout(p_dropout);
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -247,7 +251,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -247,7 +251,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
FloatD, FloatD,
decltype(d_thread_desc_mblock_m1), decltype(d_thread_desc_mblock_m1),
decltype(d_grid_desc_mblock_mperblock), decltype(d_grid_desc_mblock_mperblock),
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::Scale,
Sequence<1, 1>, Sequence<1, 1>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
...@@ -258,7 +262,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -258,7 +262,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
d_grid_desc_mblock_mperblock, d_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx_m, // mblock make_multi_index(block_work_idx_m, // mblock
get_thread_local_1d_id()), // mperblock get_thread_local_1d_id()), // mperblock
ck::tensor_operation::element_wise::PassThrough{}}; scale_p_dropout};
// copy from VGPR to Global // copy from VGPR to Global
d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_m1, d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_m1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment