"vscode:/vscode.git/clone" did not exist on "a7a6ba8e8e1be25edd6c6e6605dc6aa8cbc56e88"
Commit 34b1c320 authored by ltqin

format

parent b5a3ea2d
@@ -156,7 +156,8 @@ __global__ void
     }
     else
     {
-        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(p_a_grid + a_batch_offset,
+        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
+            p_a_grid + a_batch_offset,
             p_b_grid + b_batch_offset,
             z_matrix_ptr,
             p_b1_grid + b1_batch_offset,
@@ -994,10 +995,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         arg.m_raw_padded_,
         arg.n_raw_padded_);
     };
-    if(arg.p_drop_ > 0.0){
-        ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
-    }else{
-        ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
+    if(arg.p_drop_ > 0.0)
+    {
+        ave_time = launch_kernel(integral_constant<bool, false>{},
+                                 integral_constant<bool, true>{});
+    }
+    else
+    {
+        ave_time = launch_kernel(integral_constant<bool, false>{},
+                                 integral_constant<bool, false>{});
     }
     return ave_time;
 }
......
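Note: the hunks in this commit all touch one dispatch pattern: the runtime flags (`arg.p_drop_`, the main-K-loop predicate) are mapped onto `integral_constant` values so that `launch_kernel` receives them as compile-time booleans, and each flag combination instantiates a distinct kernel template. A minimal standalone sketch of that idea, using `std::integral_constant` and hypothetical `run_kernel`/`dispatch` names in place of the repo's own helpers:

```cpp
#include <iostream>
#include <type_traits>

// Stand-in for the real templated kernel launch: each <bool, bool>
// combination is a separate instantiation.
template <bool HasMainKBlockLoop, bool IsDropout>
float run_kernel()
{
    std::cout << "HasMainKBlockLoop=" << HasMainKBlockLoop
              << " IsDropout=" << IsDropout << '\n';
    return 0.f; // stand-in for the measured average time
}

float dispatch(bool has_main_loop, double p_drop)
{
    // The generic lambda deduces the integral_constant types, so the
    // booleans become template arguments rather than runtime parameters.
    auto launch = [](auto has_main_k_block_loop, auto is_dropout) {
        return run_kernel<has_main_k_block_loop.value, is_dropout.value>();
    };

    if(has_main_loop)
    {
        if(p_drop > 0.0)
            return launch(std::integral_constant<bool, true>{},
                          std::integral_constant<bool, true>{});
        else
            return launch(std::integral_constant<bool, true>{},
                          std::integral_constant<bool, false>{});
    }
    else
    {
        if(p_drop > 0.0)
            return launch(std::integral_constant<bool, false>{},
                          std::integral_constant<bool, true>{});
        else
            return launch(std::integral_constant<bool, false>{},
                          std::integral_constant<bool, false>{});
    }
}
```

The formatting change only splits each `launch_kernel(...)` call across two lines and expands the braces; the dispatch logic itself is unchanged.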
@@ -155,7 +155,8 @@ __global__ void
     }
     else
     {
-        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(p_a_grid + a_batch_offset,
+        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
+            p_a_grid + a_batch_offset,
             p_b_grid + b_batch_offset,
             z_matrix_ptr,
             p_b1_grid + b1_batch_offset,
@@ -1018,16 +1019,20 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
     {
         if(arg.p_drop_ > 0.0)
-            ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
+            ave_time = launch_kernel(integral_constant<bool, true>{},
+                                     integral_constant<bool, true>{});
         else
-            ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
+            ave_time = launch_kernel(integral_constant<bool, true>{},
+                                     integral_constant<bool, false>{});
     }
     else
     {
         if(arg.p_drop_ > 0.0)
-            ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
+            ave_time = launch_kernel(integral_constant<bool, false>{},
+                                     integral_constant<bool, true>{});
         else
-            ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
+            ave_time = launch_kernel(integral_constant<bool, false>{},
+                                     integral_constant<bool, false>{});
     }
     return ave_time;
......
@@ -993,16 +993,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     if(all_has_main_k_block_loop)
    {
         if(arg.p_dropout_ > 0.0)
-            ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
+            ave_time = launch_kernel(integral_constant<bool, true>{},
+                                     integral_constant<bool, true>{});
         else
-            ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
+            ave_time = launch_kernel(integral_constant<bool, true>{},
+                                     integral_constant<bool, false>{});
     }
     else if(!some_has_main_k_block_loop)
     {
         if(arg.p_dropout_ > 0.0)
-            ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
+            ave_time = launch_kernel(integral_constant<bool, false>{},
+                                     integral_constant<bool, true>{});
         else
-            ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
+            ave_time = launch_kernel(integral_constant<bool, false>{},
+                                     integral_constant<bool, false>{});
     }
     else
     {
......
@@ -1000,16 +1000,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     if(all_has_main_k_block_loop)
     {
         if(arg.p_dropout_ > 0.0)
-            ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
+            ave_time = launch_kernel(integral_constant<bool, true>{},
+                                     integral_constant<bool, true>{});
         else
-            ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
+            ave_time = launch_kernel(integral_constant<bool, true>{},
+                                     integral_constant<bool, false>{});
     }
     else if(!some_has_main_k_block_loop)
     {
         if(arg.p_dropout_ > 0.0)
-            ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
+            ave_time = launch_kernel(integral_constant<bool, false>{},
+                                     integral_constant<bool, true>{});
         else
-            ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
+            ave_time = launch_kernel(integral_constant<bool, false>{},
+                                     integral_constant<bool, false>{});
     }
     else
     {
......
@@ -1948,13 +1948,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     constexpr auto position_offset = M3 * M4;
     // save z to global
-    if constexpr(IsDropout){
+    if constexpr(IsDropout)
+    {
         if(p_z_grid)
         {
             auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
-            auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
-            auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+            auto m_local =
+                block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+            auto n_local =
+                block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
             auto m_global = m_local + m_block_data_idx_on_grid;
             auto n_global = n_local + n_block_data_idx_on_grid;
@@ -1964,13 +1967,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             auto global_elem_id =
                 (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
-            blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
+            blockwise_dropout
+                .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
                                                    decltype(z_tenor_buffer),
                                                    decltype(position_offset),
                                                    true>(
                 s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
-            z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+            z_thread_copy_vgpr_to_global.Run(
+                z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                 z_tenor_buffer,
                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
@@ -1981,8 +1986,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             ignore = z_grid_buf;
             auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
-            auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
-            auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+            auto m_local =
+                block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+            auto n_local =
+                block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
             auto m_global = m_local + m_block_data_idx_on_grid;
             auto n_global = n_local + n_block_data_idx_on_grid;
......
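Note: the z-saving path reformatted above includes the `global_elem_id` remap, which splits the raw element id by `M4` and scales the fast-moving sub-index by the padded row length, so consecutive raw ids land `raw_n_padded` elements apart in the z tensor. A standalone sketch with illustrative values (`M4 = 4` and `raw_n_padded = 16` are assumptions for demonstration, not taken from the kernel configuration):

```cpp
#include <cstdio>

int main()
{
    const int M4           = 4;  // assumed sub-tile width
    const int raw_n_padded = 16; // assumed padded row length

    for(int raw = 0; raw < 8; ++raw)
    {
        // Same expression as in the kernel: remainder picks the row,
        // quotient picks the M4-aligned offset within the row.
        int remapped = (raw % M4) * raw_n_padded + (raw / M4) * M4;
        std::printf("raw %d -> %d\n", raw, remapped);
    }
    // Prints: raw 0..3 -> 0, 16, 32, 48 and raw 4..7 -> 4, 20, 36, 52.
    return 0;
}
```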
@@ -1864,13 +1864,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     constexpr auto position_offset = M3 * M4;
     // save z to global
-    if constexpr(IsDropout){
+    if constexpr(IsDropout)
+    {
         if(p_z_grid)
         {
             auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
-            auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
-            auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+            auto m_local =
+                block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+            auto n_local =
+                block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
             auto m_global = m_local + m_block_data_idx_on_grid;
             auto n_global = n_local + n_block_data_idx_on_grid;
@@ -1880,13 +1883,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             auto global_elem_id =
                 (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
-            blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
+            blockwise_dropout
+                .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
                                                    decltype(z_tenor_buffer),
                                                    decltype(position_offset),
                                                    true>(
                 s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
-            z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+            z_thread_copy_vgpr_to_global.Run(
+                z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                 z_tenor_buffer,
                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
@@ -1897,8 +1902,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             ignore = z_grid_buf;
             auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
-            auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
-            auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+            auto m_local =
+                block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+            auto n_local =
+                block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
             auto m_global = m_local + m_block_data_idx_on_grid;
             auto n_global = n_local + n_block_data_idx_on_grid;
......
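Note: the rewrapped call keeps the `template` keyword (`blockwise_dropout` followed by `.template ApplyDropoutAttnBwdSaveZ<...>` on the next line), which dependent member calls require. A minimal sketch of the rule, with a hypothetical `ApplyDropout` member standing in for the real method:

```cpp
#include <cstdio>

struct BlockwiseDropout
{
    // Hypothetical member function template for illustration.
    template <typename Buf>
    void ApplyDropout(Buf& buf) const { std::printf("buf[0]=%d\n", buf[0]); }
};

// Inside a template, a member function template invoked on an object of
// dependent type must be prefixed with `template`; otherwise the `<` after
// the member name is parsed as a less-than operator.
template <typename Dropout, typename Buf>
void save_z(Dropout& dropout, Buf& buf)
{
    dropout.template ApplyDropout<Buf>(buf); // OK
    // dropout.ApplyDropout<Buf>(buf);       // would fail to parse here
}

int main()
{
    BlockwiseDropout d;
    int buf[1] = {7};
    save_z(d, buf);
    return 0;
}
```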
@@ -28,10 +28,7 @@ struct ReferenceSoftmax : public device::BaseOperator
                double beta,
                const std::vector<index_t> sm_reduce_dims,
                Tensor<AccDataType>* sm_stats_ptr = nullptr)
-        : in_(in),
-          out_(out),
-          sm_reduce_dims_(sm_reduce_dims),
-          sm_stats_ptr_(sm_stats_ptr)
+        : in_(in), out_(out), sm_reduce_dims_(sm_reduce_dims), sm_stats_ptr_(sm_stats_ptr)
     {
         alpha_ = static_cast<AccDataType>(alpha);
         beta_ = static_cast<AccDataType>(beta);
......