"...resnet50_tensorflow.git" did not exist on "dd5ee3bbbce83b396e8e95692686245247115760"
Commit b7b7e153 authored by danyao12

Merge branch 'mha-train-develop-dropout8bit' into mha-train-develop-bwdopt-bias

parents 5033eee6 7a8352bc
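
This merge threads the dropout probability through the Y·dY ("y dot ygrad") backward kernels: every kernel signature gains a p_drop / p_dropout argument, the batched and grouped device ops forward it from their Argument structs, and the gridwise kernel scales the reduced D values with an element_wise::Scale op (constructed from 1 - p_drop) where a PassThrough was used before. In effect, using informal index notation rather than the library's descriptors:

    D[m] = (1 - p_drop) * sum_n( y[m, n] * dy[m, n] )
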
@@ -46,7 +46,8 @@ __global__ void
 const DGridDescriptor_M d_grid_desc_m,
 const Block2CTileMap block_2_ctile_map,
 const index_t batch_count,
-const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+const float p_drop)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
 defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
@@ -66,7 +67,8 @@ __global__ void
 p_d_grid + d_batch_offset,
 y_grid_desc_mblock_mperblock_nblock_nperblock,
 d_grid_desc_m,
-block_2_ctile_map);
+block_2_ctile_map,
+p_drop);
 #else
 ignore = p_y_grid;
@@ -77,6 +79,7 @@ __global__ void
 ignore = block_2_ctile_map;
 ignore = batch_count;
 ignore = compute_base_ptr_of_batch;
+ignore = p_drop;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
@@ -1131,7 +1134,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
 arg.d_grid_desc_m_,
 arg.d_block_2_ctile_map_,
 arg.batch_count_,
-arg.compute_base_ptr_of_batch_);
+arg.compute_base_ptr_of_batch_,
+arg.p_drop_);
 };
 ave_time = launch_kernel();
......
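
The hunks above add the scalar to the batched V1 kernel's signature and forward arg.p_drop_ at launch time. A minimal HIP-style sketch of that plumbing pattern, with hypothetical names (ydotygrad_kernel, Argument, launch) standing in for the library's templated kernel and device op:

    #include <hip/hip_runtime.h>

    // One thread per row m: D[m] = (1 - p_drop) * sum_n y[m, n] * dy[m, n].
    __global__ void ydotygrad_kernel(
        const float* p_y, const float* p_dy, float* p_d, int M, int N, float p_drop)
    {
        const float scale = 1.0f - p_drop;
        const int m = blockIdx.x * blockDim.x + threadIdx.x;
        if(m >= M)
            return;

        float acc = 0.0f;
        for(int n = 0; n < N; ++n)
            acc += p_y[m * N + n] * p_dy[m * N + n];

        p_d[m] = acc * scale;
    }

    // Hypothetical stand-in for the device op's Argument struct.
    struct Argument
    {
        const float* p_y_grid_;
        const float* p_ygrad_grid_;
        float* p_d_grid_;
        int m_;
        int n_;
        float p_drop_; // the new member forwarded into the kernel, as in the diff
    };

    void launch(const Argument& arg, hipStream_t stream)
    {
        const int block_size = 256;
        const int grid_size  = (arg.m_ + block_size - 1) / block_size;

        hipLaunchKernelGGL(ydotygrad_kernel,
                           dim3(grid_size),
                           dim3(block_size),
                           0,
                           stream,
                           arg.p_y_grid_,
                           arg.p_ygrad_grid_,
                           arg.p_d_grid_,
                           arg.m_,
                           arg.n_,
                           arg.p_drop_); // the extra argument added by this merge
    }
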
@@ -46,7 +46,8 @@ __global__ void
 const DGridDescriptor_M d_grid_desc_m,
 const Block2CTileMap block_2_ctile_map,
 const index_t batch_count,
-const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+const float p_drop)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
 defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
@@ -66,7 +67,8 @@ __global__ void
 p_d_grid + d_batch_offset,
 y_grid_desc_mblock_mperblock_nblock_nperblock,
 d_grid_desc_m,
-block_2_ctile_map);
+block_2_ctile_map,
+p_drop);
 #else
 ignore = p_y_grid;
@@ -77,6 +79,7 @@ __global__ void
 ignore = block_2_ctile_map;
 ignore = batch_count;
 ignore = compute_base_ptr_of_batch;
+ignore = p_drop;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
@@ -1143,7 +1146,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
 arg.d_grid_desc_m_,
 arg.d_block_2_ctile_map_,
 arg.batch_count_,
-arg.compute_base_ptr_of_batch_);
+arg.compute_base_ptr_of_batch_,
+arg.p_drop_);
 };
 ave_time = launch_kernel();
......
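
Both batched variants wrap the kernel body in an architecture guard and consume every argument, including the new p_drop, in the #else branch so builds for unsupported targets stay warning-free. A compressed sketch of that guard, using plain (void) casts where the library uses its ignore helper; the kernel name and parameters are hypothetical:

    #include <hip/hip_runtime.h>

    __global__ void ydotygrad_guard_sketch(
        const float* p_y_grid, const float* p_ygrad_grid, float* p_d_grid, float p_drop)
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
        defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
        // Supported architecture (or host pass): the real kernel body would run here.
        if(threadIdx.x == 0)
            p_d_grid[0] = (1.0f - p_drop) * p_y_grid[0] * p_ygrad_grid[0];
    #else
        // Unsupported target: reference every parameter so the compiler does not warn.
        // The library does this with "ignore = param;"; a (void) cast has the same effect.
        (void)p_y_grid;
        (void)p_ygrad_grid;
        (void)p_d_grid;
        (void)p_drop; // the newly added argument has to be consumed here as well
    #endif
    }
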
@@ -33,7 +33,9 @@ __global__ void
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
 kernel_grouped_multihead_attention_backward_ydotygrad_v1(
-const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
+const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
+const index_t group_count,
+const float p_dropout)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
 defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
@@ -74,10 +76,12 @@ __global__ void
 arg_ptr[group_id].p_d_grid_ + d_batch_offset,
 arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
 arg_ptr[group_id].d_grid_desc_m_,
-arg_ptr[group_id].d_block_2_ctile_map_);
+arg_ptr[group_id].d_block_2_ctile_map_,
+p_dropout);
 #else
 ignore = group_kernel_args;
 ignore = group_count;
+ignore = p_dropout;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
@@ -1175,7 +1179,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
 dim3(BlockSize),
 0,
 cast_pointer_to_constant_address_space(arg.p_workspace_),
-arg.group_count_);
+arg.group_count_,
+arg.p_dropout_);
 };
 ave_time = launch_kernel();
 }
......
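
In the grouped variants the per-group arguments already live in a workspace passed through constant address space, so the dropout value is added as a second group-uniform scalar next to group_count instead of being duplicated into every group's argument struct. A simplified sketch of that layout, with hypothetical names and an ordinary device pointer standing in for the CK_CONSTANT_ADDRESS_SPACE workspace:

    #include <hip/hip_runtime.h>

    // Hypothetical per-group argument block; the real one also carries tensor
    // descriptors and a block-to-tile map.
    struct GroupKernelArg
    {
        const float* p_y;
        const float* p_dy;
        float* p_d;
        int m;
        int n;
        int block_begin; // first workgroup assigned to this group
        int block_end;   // one past the last workgroup of this group
    };

    __global__ void grouped_ydotygrad_sketch(const GroupKernelArg* group_args,
                                             int group_count,
                                             float p_dropout)
    {
        const int block_id = static_cast<int>(blockIdx.x);

        // Find the group this workgroup belongs to (the library performs a
        // similar search over per-group block ranges).
        int group_id = 0;
        while(group_id < group_count && block_id >= group_args[group_id].block_end)
            ++group_id;
        if(group_id == group_count)
            return;

        const GroupKernelArg arg = group_args[group_id];
        const int m = (block_id - arg.block_begin) * static_cast<int>(blockDim.x) +
                      static_cast<int>(threadIdx.x);
        if(m >= arg.m)
            return;

        float acc = 0.0f;
        for(int n = 0; n < arg.n; ++n)
            acc += arg.p_y[m * arg.n + n] * arg.p_dy[m * arg.n + n];

        // p_dropout is shared by every group, mirroring how arg.p_dropout_ is
        // passed once per launch in the diff above.
        arg.p_d[m] = acc * p_dropout;
    }
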
@@ -32,7 +32,9 @@ __global__ void
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
 kernel_grouped_multihead_attention_backward_ydotygrad_v2(
-const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
+const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
+const index_t group_count,
+const float p_dropout)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
 defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
@@ -73,10 +75,12 @@ __global__ void
 arg_ptr[group_id].p_d_grid_ + d_batch_offset,
 arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
 arg_ptr[group_id].d_grid_desc_m_,
-arg_ptr[group_id].d_block_2_ctile_map_);
+arg_ptr[group_id].d_block_2_ctile_map_,
+p_dropout);
 #else
 ignore = group_kernel_args;
 ignore = group_count;
+ignore = p_dropout;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
@@ -1244,7 +1248,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
 dim3(BlockSize),
 0,
 cast_pointer_to_constant_address_space(arg.p_workspace_),
-arg.group_count_);
+arg.group_count_,
+arg.p_dropout_);
 };
 ave_time = launch_kernel();
 }
......
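
Each device op wraps its launch in a lambda and reports an average time (ave_time = launch_kernel()). A minimal sketch of that measurement idea using HIP events; this is an assumption about the mechanism, not a reproduction of the library's own timing helper:

    #include <hip/hip_runtime.h>

    // Run a launch callable nrepeat times and return the average wall time per
    // launch in milliseconds, measured with HIP events.
    template <typename Launch>
    float time_kernel(Launch&& launch_kernel, int nrepeat = 10)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);

        hipEventRecord(start, nullptr);
        for(int i = 0; i < nrepeat; ++i)
            launch_kernel(); // e.g. a lambda that calls hipLaunchKernelGGL(...)
        hipEventRecord(stop, nullptr);
        hipEventSynchronize(stop);

        float total_ms = 0.0f;
        hipEventElapsedTime(&total_ms, start, stop);

        hipEventDestroy(start);
        hipEventDestroy(stop);

        return total_ms / static_cast<float>(nrepeat);
    }
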
@@ -142,8 +142,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
 const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
 y_grid_desc_mblock_mperblock_nblock_nperblock,
 const DGridDesc_M& d_grid_desc_m,
-const Block2CTileMap& block_2_ctile_map)
+const Block2CTileMap& block_2_ctile_map,
+const float p_drop)
 {
+const FloatD p_dropout = type_convert<FloatD>(1.0f - p_drop);
+const tensor_operation::element_wise::Scale scale_p_dropout(p_dropout);
 const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -247,7 +251,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
 FloatD,
 decltype(d_thread_desc_mblock_m1),
 decltype(d_grid_desc_mblock_mperblock),
-ck::tensor_operation::element_wise::PassThrough,
+ck::tensor_operation::element_wise::Scale,
 Sequence<1, 1>,
 Sequence<0, 1>,
 1,
@@ -258,7 +262,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
 d_grid_desc_mblock_mperblock,
 make_multi_index(block_work_idx_m, // mblock
 get_thread_local_1d_id()), // mperblock
-ck::tensor_operation::element_wise::PassThrough{}};
+scale_p_dropout};
 // copy from VGPR to Global
 d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_m1,
......
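
The gridwise change is the functional core of the merge: the thread-level copy of D from VGPR to global memory now applies a Scale element-wise op constructed from 1 - p_drop instead of a PassThrough. A host-side sketch of that swap; the PassThrough and Scale functors below are illustrative stand-ins for ck::tensor_operation::element_wise::PassThrough / Scale, and thread_copy stands in for the threadwise VGPR-to-global transfer:

    #include <cstdio>

    // Stand-in element-wise ops: y = x, and y = scale * x.
    struct PassThrough
    {
        void operator()(float& y, const float& x) const { y = x; }
    };

    struct Scale
    {
        explicit Scale(float scale) : scale_(scale) {}
        void operator()(float& y, const float& x) const { y = scale_ * x; }
        float scale_;
    };

    // Stand-in for the VGPR->Global copy: the element-wise op runs on each
    // element as it is written to the destination buffer.
    template <typename ElementOp>
    void thread_copy(const float* src, float* dst, int n, const ElementOp& op)
    {
        for(int i = 0; i < n; ++i)
            op(dst[i], src[i]);
    }

    int main()
    {
        const float p_drop = 0.1f;
        const Scale scale_p_dropout(1.0f - p_drop); // mirrors the construction in the diff

        float d_vgpr[2]   = {3.0f, 7.0f};
        float d_global[2] = {0.0f, 0.0f};

        thread_copy(d_vgpr, d_global, 2, scale_p_dropout);
        std::printf("%f %f\n", d_global[0], d_global[1]); // 2.700000 6.300000
    }
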