Commit 7e402f6a authored by guangzlu

added switch for dropout in bwd pass

parent 665b08cf
@@ -47,7 +47,8 @@ template <typename GridwiseGemm,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
           typename C0MatrixMask,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool IsDropout>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -111,7 +112,8 @@ __global__ void
         ck::philox ph(seed, global_thread_id, offset);
         ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);

-        GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
+            p_a_grid + a_batch_offset,
             p_b_grid + b_batch_offset,
             z_matrix_ptr,
             p_b1_grid + b1_batch_offset,
@@ -786,6 +788,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     y_grid_desc_m_o_);
             }

+            is_dropout_ = p_drop_ > 0.0;
             seed_   = std::get<0>(seeds);
             offset_ = std::get<1>(seeds);
@@ -877,6 +880,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
         float p_drop_;
+        bool is_dropout_;
         unsigned long long seed_;
         unsigned long long offset_;
     };
@@ -898,7 +902,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float ave_time = 0;

-        auto launch_kernel = [&](auto has_main_k_block_loop_) {
+        auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
             const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
                 GridwiseGemm,
                 DataType,
@@ -920,7 +924,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch,
                 C0MatrixMask,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                is_dropout_>;

             return launch_and_time_kernel(stream_config,
                                           kernel,
@@ -970,7 +975,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         // {
         //     ave_time = launch_kernel(integral_constant<bool, false>{});
         // }
-        ave_time = launch_kernel(integral_constant<bool, false>{});
+        if(arg.is_dropout_)
+        {
+            ave_time = launch_kernel(integral_constant<bool, false>{},
+                                     integral_constant<bool, true>{});
+        }
+        else
+        {
+            ave_time = launch_kernel(integral_constant<bool, false>{},
+                                     integral_constant<bool, false>{});
+        }
+        // ave_time = launch_kernel(integral_constant<bool, false>{});
 #endif
         return ave_time;
     }
......
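The host-side dispatch above folds the runtime flag arg.is_dropout_ (alongside the existing has-main-K-block-loop flag) into compile-time template parameters via integral_constant, so each flag combination gets its own kernel instantiation. A minimal standalone sketch of the same pattern, using hypothetical names rather than the CK API:

#include <type_traits>

// Toy kernel body: IsDropout is a compile-time constant, so the
// IsDropout == false instantiation contains no dropout code at all.
template <bool HasMainKBlockLoop, bool IsDropout>
void run_kernel()
{
    if constexpr(IsDropout)
    {
        // generate and apply the dropout mask here
    }
    // ... rest of the backward pass ...
}

// Host side: fold the runtime bools into template arguments exactly once,
// mirroring launch_kernel(integral_constant<...>{}, integral_constant<...>{}).
inline void dispatch(bool has_main_loop, bool is_dropout)
{
    auto launch = [](auto has_main_loop_c, auto is_dropout_c) {
        run_kernel<decltype(has_main_loop_c)::value, decltype(is_dropout_c)::value>();
    };

    using ic_true  = std::integral_constant<bool, true>;
    using ic_false = std::integral_constant<bool, false>;

    if(has_main_loop)
        is_dropout ? launch(ic_true{}, ic_true{}) : launch(ic_true{}, ic_false{});
    else
        is_dropout ? launch(ic_false{}, ic_true{}) : launch(ic_false{}, ic_false{});
}

The payoff is that the no-dropout instantiation carries no dropout instructions or mask traffic; the cost is one extra kernel instantiation per flag combination.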
@@ -46,7 +46,8 @@ template <typename GridwiseGemm,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
           typename C0MatrixMask,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool IsDropout>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -110,7 +111,8 @@ __global__ void
         ck::philox ph(seed, global_thread_id, offset);
         ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);

-        GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
+            p_a_grid + a_batch_offset,
             p_b_grid + b_batch_offset,
             z_matrix_ptr,
             p_b1_grid + b1_batch_offset,
@@ -784,6 +786,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     y_grid_desc_m_o_);
             }

+            is_dropout_ = p_drop_ > 0.0;
             seed_   = std::get<0>(seeds);
             offset_ = std::get<1>(seeds);
@@ -875,6 +878,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
         float p_drop_;
+        bool is_dropout_;
         unsigned long long seed_;
         unsigned long long offset_;
     };
@@ -900,7 +904,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float ave_time = 0;

-        auto launch_kernel = [&](auto has_main_k_block_loop_) {
+        auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
             const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
                 GridwiseGemm,
                 DataType,
@@ -922,7 +926,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch,
                 C0MatrixMask,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                is_dropout_>;

             return launch_and_time_kernel(stream_config,
                                           kernel,
@@ -966,11 +971,29 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 #if 1
         if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
         {
-            ave_time = launch_kernel(integral_constant<bool, true>{});
+            if(arg.is_dropout_)
+            {
+                ave_time = launch_kernel(integral_constant<bool, true>{},
+                                         integral_constant<bool, true>{});
+            }
+            else
+            {
+                ave_time = launch_kernel(integral_constant<bool, true>{},
+                                         integral_constant<bool, false>{});
+            }
         }
         else
         {
-            ave_time = launch_kernel(integral_constant<bool, false>{});
+            if(arg.is_dropout_)
+            {
+                ave_time = launch_kernel(integral_constant<bool, false>{},
+                                         integral_constant<bool, true>{});
+            }
+            else
+            {
+                ave_time = launch_kernel(integral_constant<bool, false>{},
+                                         integral_constant<bool, false>{});
+            }
         }
 #endif
         return ave_time;
......
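Both batched kernels seed the dropout RNG per thread as ck::philox ph(seed, global_thread_id, offset), i.e. a counter-based generator whose output is a pure function of (seed, thread id, offset). That is what lets the backward pass regenerate exactly the mask the forward pass drew, without storing any per-thread RNG state. A rough illustration of the idea with a stand-in mixing function (illustrative only, not CK's Philox implementation):

#include <cstdint>

// Stand-in for a counter-based generator: the random stream is a pure
// function of (seed, thread_id, counter), so forward and backward passes
// that use the same triple see the same numbers. (Arbitrary bit mixer,
// not the Philox algorithm.)
inline uint32_t counter_based_rand(uint64_t seed, uint64_t thread_id, uint64_t counter)
{
    uint64_t x = seed ^ (thread_id * 0x9E3779B97F4A7C15ull) ^ (counter << 32);
    x ^= x >> 33; x *= 0xFF51AFD7ED558CCDull;
    x ^= x >> 33; x *= 0xC4CEB9FE1A85EC53ull;
    x ^= x >> 33;
    return static_cast<uint32_t>(x);
}

// Keep an element with probability (1 - p_drop); both passes call this
// with identical arguments and therefore agree on the mask.
inline bool keep_element(uint64_t seed, uint64_t thread_id, uint64_t counter, float p_drop)
{
    const float u = counter_based_rand(seed, thread_id, counter) * (1.0f / 4294967296.0f);
    return u >= p_drop;
}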
@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
           typename AccElementwiseOperation,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool IsDropout>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -99,7 +100,7 @@ __global__ void
             (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                     : arg_ptr[group_id].p_z_grid_ + z_batch_offset);

-        GridwiseGemm::template Run<HasMainKBlockLoop>(
+        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
             arg_ptr[group_id].p_a_grid_ + a_batch_offset,
             arg_ptr[group_id].p_b_grid_ + b_batch_offset,
             z_matrix_ptr,
@@ -685,6 +686,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
               c_element_op_{c_element_op},
               p_dropout_{p_drop}
         {
+            is_dropout_ = p_drop > 0.0;
             seed_   = std::get<0>(seeds);
             offset_ = std::get<1>(seeds);
@@ -867,6 +869,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         CElementwiseOperation c_element_op_;
         float p_dropout_;
+        bool is_dropout_;
         unsigned long long seed_;
         unsigned long long offset_;
@@ -908,7 +911,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float ave_time = 0;

-        auto launch_kernel = [&](auto has_main_k_block_loop_) {
+        auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
             const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
                 GridwiseGemm,
                 GroupKernelArg,
@@ -917,7 +920,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 AccElementwiseOperation,
                 B1ElementwiseOperation,
                 CElementwiseOperation,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                is_dropout_>;

             return launch_and_time_kernel(
                 stream_config,
@@ -941,11 +945,29 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         // to concern Gemm0's loop
         if(all_has_main_k_block_loop)
         {
-            ave_time = launch_kernel(integral_constant<bool, true>{});
+            if(arg.is_dropout_)
+            {
+                ave_time = launch_kernel(integral_constant<bool, true>{},
+                                         integral_constant<bool, true>{});
+            }
+            else
+            {
+                ave_time = launch_kernel(integral_constant<bool, true>{},
+                                         integral_constant<bool, false>{});
+            }
         }
         else if(!some_has_main_k_block_loop)
         {
-            ave_time = launch_kernel(integral_constant<bool, false>{});
+            if(arg.is_dropout_)
+            {
+                ave_time = launch_kernel(integral_constant<bool, false>{},
+                                         integral_constant<bool, true>{});
+            }
+            else
+            {
+                ave_time = launch_kernel(integral_constant<bool, false>{},
+                                         integral_constant<bool, false>{});
+            }
         }
         else
         {
......
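The grouped variant launches one kernel over many problem groups, so the has-main-K-block-loop decision has to be consistent across groups: the dispatch above picks a single instantiation only in the two uniform cases (all_has_main_k_block_loop, or none of them), and the mixed case is handled separately outside this hunk. A small sketch of how such flags could be computed, with a hypothetical helper that is not part of this diff:

#include <vector>

// Hypothetical helper mirroring the grouped dispatch above: a single kernel
// instantiation is reused only when every group agrees on whether the K loop
// has a main block; these two flags drive that decision.
struct GroupLoopFlags
{
    bool all_has_main_k_block_loop  = true;
    bool some_has_main_k_block_loop = false;
};

template <typename GridwiseGemmT>
GroupLoopFlags compute_loop_flags(const std::vector<int>& k_per_group)
{
    GroupLoopFlags flags;
    for(int K : k_per_group)
    {
        const bool has_main = GridwiseGemmT::CalculateHasMainKBlockLoop(K);
        flags.all_has_main_k_block_loop  = flags.all_has_main_k_block_loop && has_main;
        flags.some_has_main_k_block_loop = flags.some_has_main_k_block_loop || has_main;
    }
    return flags;
}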
@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
           typename AccElementwiseOperation,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool IsDropout>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -99,7 +100,7 @@ __global__ void
             (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                     : arg_ptr[group_id].p_z_grid_ + z_batch_offset);

-        GridwiseGemm::template Run<HasMainKBlockLoop>(
+        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
             arg_ptr[group_id].p_a_grid_ + a_batch_offset,
             arg_ptr[group_id].p_b_grid_ + b_batch_offset,
             z_matrix_ptr,
@@ -678,6 +679,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
               c_element_op_{c_element_op},
               p_dropout_{p_drop}
         {
+            is_dropout_ = p_drop > 0.0;
             seed_   = std::get<0>(seeds);
             offset_ = std::get<1>(seeds);
@@ -860,6 +862,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         CElementwiseOperation c_element_op_;
         float p_dropout_;
+        bool is_dropout_;
         unsigned long long seed_;
         unsigned long long offset_;
@@ -900,7 +903,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float ave_time = 0;

-        auto launch_kernel = [&](auto has_main_k_block_loop_) {
+        auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
             const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
                 GridwiseGemm,
                 GroupKernelArg,
@@ -909,7 +912,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 AccElementwiseOperation,
                 B1ElementwiseOperation,
                 CElementwiseOperation,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                is_dropout_>;

             return launch_and_time_kernel(
                 stream_config,
@@ -933,11 +937,29 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // to concern Gemm0's loop
         if(all_has_main_k_block_loop)
         {
-            ave_time = launch_kernel(integral_constant<bool, true>{});
+            if(arg.is_dropout_)
+            {
+                ave_time = launch_kernel(integral_constant<bool, true>{},
+                                         integral_constant<bool, true>{});
+            }
+            else
+            {
+                ave_time = launch_kernel(integral_constant<bool, true>{},
+                                         integral_constant<bool, false>{});
+            }
         }
         else if(!some_has_main_k_block_loop)
         {
-            ave_time = launch_kernel(integral_constant<bool, false>{});
+            if(arg.is_dropout_)
+            {
+                ave_time = launch_kernel(integral_constant<bool, false>{},
+                                         integral_constant<bool, true>{});
+            }
+            else
+            {
+                ave_time = launch_kernel(integral_constant<bool, false>{},
+                                         integral_constant<bool, false>{});
+            }
         }
         else
         {
......
@@ -1230,6 +1230,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     }

     template <bool HasMainKBlockLoop,
+              bool IsDropout,
               typename Block2CTileMap,
               typename C0MatrixMask,
               typename VGradGridDescriptor_N_O,
@@ -1957,6 +1958,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             // scaling is already performed in the preceding statements with s_element_op
             blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);

+            if constexpr(IsDropout)
+            {
                 // save z to global
                 if(p_z_grid)
                 {
@@ -1982,6 +1985,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                         make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
+                    z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                        make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
                 }
                 else
                 {
@@ -1990,6 +1996,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
                         s_slash_p_thread_buf, ph);
                 }
+            }

             block_sync_lds(); // wait for gemm1 LDS read
@@ -2176,9 +2183,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 Gemm2::b_block_reset_copy_step); // rewind M
             kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                 kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
-            z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+            // z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+            //     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+            //     make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
......
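Inside the gridwise kernel, writing the z dropout mask to global memory and calling blockwise_dropout.ApplyDropout are now wrapped in if constexpr(IsDropout), so the IsDropout = false instantiation is compiled without any dropout work or z-matrix traffic. The per-outer-iteration advance of the z destination window is likewise moved into the guarded block (the old advance at the loop tail is commented out), so the z descriptor is only touched when z is actually written. Conceptually, the guarded region applies inverted dropout to the softmax probabilities; a scalar sketch of that convention (illustrative only, not the blockwise implementation):

// Scalar illustration of inverted dropout applied to one attention
// probability value p. When IsDropout is false, the whole branch is
// discarded at compile time and the value passes through untouched.
template <bool IsDropout>
float apply_dropout(float p, bool keep, float p_drop)
{
    if constexpr(IsDropout)
    {
        const float rescale = 1.0f / (1.0f - p_drop); // assumes p_drop < 1
        return keep ? p * rescale : 0.0f;
    }
    else
    {
        return p;
    }
}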
@@ -1140,6 +1140,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     }

     template <bool HasMainKBlockLoop,
+              bool IsDropout,
               typename Block2CTileMap,
               typename C0MatrixMask,
               typename VGradGridDescriptor_N_O,
@@ -1852,6 +1853,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             // scaling is already performed in the preceding statements with s_element_op
             blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);

+            if constexpr(IsDropout)
+            {
                 // save z to global
                 if(p_z_grid)
                 {
@@ -1877,6 +1880,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                         make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
+                    z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                        make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
                 }
                 else
                 {
@@ -1885,6 +1891,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
                         s_slash_p_thread_buf, ph);
                 }
+            }

             block_sync_lds(); // wait for gemm1 LDS read
@@ -2126,9 +2133,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                 kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
-            z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+            // z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+            //     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+            //     make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
......