Commit 7e402f6a authored by guangzlu's avatar guangzlu
Browse files

added switch for dropout in bwd pass

parent 665b08cf
......@@ -47,7 +47,8 @@ template <typename GridwiseGemm,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
......@@ -111,7 +112,8 @@ __global__ void
ck::philox ph(seed, global_thread_id, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
......@@ -786,6 +788,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
y_grid_desc_m_o_);
}
is_dropout_ = p_drop_ > 0.0;
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......@@ -877,6 +880,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_drop_;
bool is_dropout_;
unsigned long long seed_;
unsigned long long offset_;
};
......@@ -898,7 +902,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
DataType,
......@@ -920,7 +924,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>;
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(stream_config,
kernel,
......@@ -970,7 +975,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// {
// ave_time = launch_kernel(integral_constant<bool, false>{});
// }
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
// ave_time = launch_kernel(integral_constant<bool, false>{});
#endif
return ave_time;
}
......
......@@ -46,7 +46,8 @@ template <typename GridwiseGemm,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
......@@ -110,7 +111,8 @@ __global__ void
ck::philox ph(seed, global_thread_id, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
......@@ -784,6 +786,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
y_grid_desc_m_o_);
}
is_dropout_ = p_drop_ > 0.0;
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......@@ -875,6 +878,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_drop_;
bool is_dropout_;
unsigned long long seed_;
unsigned long long offset_;
};
......@@ -900,7 +904,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
DataType,
......@@ -922,7 +926,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>;
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(stream_config,
kernel,
......@@ -966,11 +971,29 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
#if 1
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
#endif
return ave_time;
......
......@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
......@@ -99,7 +100,7 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -685,6 +686,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
c_element_op_{c_element_op},
p_dropout_{p_drop}
{
is_dropout_ = p_drop > 0.0;
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......@@ -867,6 +869,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
CElementwiseOperation c_element_op_;
float p_dropout_;
bool is_dropout_;
unsigned long long seed_;
unsigned long long offset_;
......@@ -908,7 +911,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
GroupKernelArg,
......@@ -917,7 +920,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(
stream_config,
......@@ -941,11 +945,29 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else if(!some_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
else
{
......
......@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
......@@ -99,7 +100,7 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -678,6 +679,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
c_element_op_{c_element_op},
p_dropout_{p_drop}
{
is_dropout_ = p_drop > 0.0;
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......@@ -860,6 +862,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
CElementwiseOperation c_element_op_;
float p_dropout_;
bool is_dropout_;
unsigned long long seed_;
unsigned long long offset_;
......@@ -900,7 +903,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
GroupKernelArg,
......@@ -909,7 +912,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(
stream_config,
......@@ -933,11 +937,29 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else if(!some_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
else
{
......
......@@ -1230,6 +1230,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
typename C0MatrixMask,
typename VGradGridDescriptor_N_O,
......@@ -1957,6 +1958,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
if constexpr(IsDropout)
{
// save z to global
if(p_z_grid)
{
......@@ -1982,6 +1985,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
else
{
......@@ -1990,6 +1996,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
}
block_sync_lds(); // wait for gemm1 LDS read
......@@ -2176,9 +2183,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Gemm2::b_block_reset_copy_step); // rewind M
kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
......
......@@ -1140,6 +1140,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
typename C0MatrixMask,
typename VGradGridDescriptor_N_O,
......@@ -1852,6 +1853,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
if constexpr(IsDropout)
{
// save z to global
if(p_z_grid)
{
......@@ -1877,6 +1880,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
else
{
......@@ -1885,6 +1891,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
}
block_sync_lds(); // wait for gemm1 LDS read
......@@ -2126,9 +2133,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment