Commit 7e402f6a authored by guangzlu's avatar guangzlu
Browse files

Added a compile-time switch (IsDropout template parameter) to enable/disable dropout in the backward pass

parent 665b08cf
......@@ -47,7 +47,8 @@ template <typename GridwiseGemm,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
......@@ -111,34 +112,35 @@ __global__ void
ck::philox ph(seed, global_thread_id, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
vgrad_grid_desc_n_o,
ygrad_grid_desc_o0_m_o1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
vgrad_grid_desc_n_o,
ygrad_grid_desc_o0_m_o1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph);
#else
ignore = p_a_grid;
ignore = p_b_grid;
......@@ -786,8 +788,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
y_grid_desc_m_o_);
}
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
is_dropout_ = p_drop_ > 0.0;
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
......@@ -877,6 +880,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_drop_;
bool is_dropout_;
unsigned long long seed_;
unsigned long long offset_;
};
......@@ -898,7 +902,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
DataType,
......@@ -920,7 +924,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>;
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(stream_config,
kernel,
......@@ -970,7 +975,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// {
// ave_time = launch_kernel(integral_constant<bool, false>{});
// }
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
// ave_time = launch_kernel(integral_constant<bool, false>{});
#endif
return ave_time;
}
......
......@@ -46,7 +46,8 @@ template <typename GridwiseGemm,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
......@@ -110,34 +111,35 @@ __global__ void
ck::philox ph(seed, global_thread_id, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
vgrad_grid_desc_n_o,
ygrad_grid_desc_m0_o_m1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
vgrad_grid_desc_n_o,
ygrad_grid_desc_m0_o_m1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph);
#else
ignore = p_a_grid;
ignore = p_b_grid;
......@@ -784,8 +786,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
y_grid_desc_m_o_);
}
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
is_dropout_ = p_drop_ > 0.0;
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
......@@ -875,6 +878,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_drop_;
bool is_dropout_;
unsigned long long seed_;
unsigned long long offset_;
};
......@@ -900,7 +904,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
DataType,
......@@ -922,7 +926,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>;
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(stream_config,
kernel,
......@@ -966,11 +971,29 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
#if 1
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
#endif
return ave_time;
......
......@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
......@@ -99,7 +100,7 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -685,8 +686,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
c_element_op_{c_element_op},
p_dropout_{p_drop}
{
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
is_dropout_ = p_drop > 0.0;
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
group_count_ = ck::type_convert<ck::index_t>(problem_desc_vec.size());
......@@ -867,6 +869,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
CElementwiseOperation c_element_op_;
float p_dropout_;
bool is_dropout_;
unsigned long long seed_;
unsigned long long offset_;
......@@ -908,7 +911,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
GroupKernelArg,
......@@ -917,7 +920,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(
stream_config,
......@@ -941,11 +945,29 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else if(!some_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
else
{
......
......@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
......@@ -99,7 +100,7 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -678,8 +679,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
c_element_op_{c_element_op},
p_dropout_{p_drop}
{
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
is_dropout_ = p_drop > 0.0;
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
group_count_ = ck::type_convert<ck::index_t>(problem_desc_vec.size());
......@@ -860,6 +862,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
CElementwiseOperation c_element_op_;
float p_dropout_;
bool is_dropout_;
unsigned long long seed_;
unsigned long long offset_;
......@@ -900,7 +903,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
GroupKernelArg,
......@@ -909,7 +912,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(
stream_config,
......@@ -933,11 +937,29 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else if(!some_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
else
{
......
......@@ -1230,6 +1230,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
typename C0MatrixMask,
typename VGradGridDescriptor_N_O,
......@@ -1957,38 +1958,44 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if(p_z_grid)
if constexpr(IsDropout)
{
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
// save z to global
if(p_z_grid)
{
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
}
else
{
ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
else
{
ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
}
block_sync_lds(); // wait for gemm1 LDS read
......@@ -2176,9 +2183,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Gemm2::b_block_reset_copy_step); // rewind M
kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
......
......@@ -1140,6 +1140,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
typename C0MatrixMask,
typename VGradGridDescriptor_N_O,
......@@ -1852,38 +1853,44 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if(p_z_grid)
if constexpr(IsDropout)
{
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
// save z to global
if(p_z_grid)
{
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
}
else
{
ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
else
{
ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
}
block_sync_lds(); // wait for gemm1 LDS read
......@@ -2126,9 +2133,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment