Commit 52478ac3 authored by ltqin's avatar ltqin
Browse files

change name to d_thread_desc_mblock_m1

parent 2aa9cbee
...@@ -156,36 +156,37 @@ __global__ void ...@@ -156,36 +156,37 @@ __global__ void
} }
else else
{ {
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(p_a_grid + a_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_b_grid + b_batch_offset, p_a_grid + a_batch_offset,
z_matrix_ptr, p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset, z_matrix_ptr,
p_c_grid + c_batch_offset, p_b1_grid + b1_batch_offset,
p_lse_grid + lse_batch_offset, p_c_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset, p_lse_grid + lse_batch_offset,
p_qgrad_grid + a_batch_offset, p_ygrad_grid + c_batch_offset,
p_kgrad_grid + b_batch_offset, p_qgrad_grid + a_batch_offset,
p_vgrad_grid + b1_batch_offset, p_kgrad_grid + b_batch_offset,
p_shared, p_vgrad_grid + b1_batch_offset,
a_element_op, p_shared,
b_element_op, a_element_op,
acc_element_op, b_element_op,
b1_element_op, acc_element_op,
c_element_op, b1_element_op,
a_grid_desc_ak0_m_ak1, c_element_op,
b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
c_grid_desc_mblock_mperblock_nblock_nperblock, b1_grid_desc_bk0_n_bk1,
lse_grid_desc_m, c_grid_desc_mblock_mperblock_nblock_nperblock,
ygrad_grid_desc_o0_m_o1, lse_grid_desc_m,
block_2_ctile_map, ygrad_grid_desc_o0_m_o1,
c0_matrix_mask, block_2_ctile_map,
p_drop, c0_matrix_mask,
ph, p_drop,
z_random_matrix_offset, ph,
raw_n_padded, z_random_matrix_offset,
0); raw_n_padded,
0);
} }
#else #else
ignore = p_a_grid; ignore = p_a_grid;
...@@ -1000,10 +1001,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1000,10 +1001,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg.m_raw_padded_, arg.m_raw_padded_,
arg.n_raw_padded_); arg.n_raw_padded_);
}; };
if(arg.p_drop_ > 0.0){ if(arg.p_drop_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{}); {
}else{ ave_time = launch_kernel(integral_constant<bool, false>{},
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{}); integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
} }
return ave_time; return ave_time;
} }
......
...@@ -155,36 +155,37 @@ __global__ void ...@@ -155,36 +155,37 @@ __global__ void
} }
else else
{ {
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(p_a_grid + a_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_b_grid + b_batch_offset, p_a_grid + a_batch_offset,
z_matrix_ptr, p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset, z_matrix_ptr,
p_c_grid + c_batch_offset, p_b1_grid + b1_batch_offset,
p_lse_grid + lse_batch_offset, p_c_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset, p_lse_grid + lse_batch_offset,
p_qgrad_grid + a_batch_offset, p_ygrad_grid + c_batch_offset,
p_kgrad_grid + b_batch_offset, p_qgrad_grid + a_batch_offset,
p_vgrad_grid + b1_batch_offset, p_kgrad_grid + b_batch_offset,
p_shared, p_vgrad_grid + b1_batch_offset,
a_element_op, p_shared,
b_element_op, a_element_op,
acc_element_op, b_element_op,
b1_element_op, acc_element_op,
c_element_op, b1_element_op,
a_grid_desc_ak0_m_ak1, c_element_op,
b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
c_grid_desc_mblock_mperblock_nblock_nperblock, b1_grid_desc_bk0_n_bk1,
lse_grid_desc_m, c_grid_desc_mblock_mperblock_nblock_nperblock,
ygrad_grid_desc_m0_o_m1, lse_grid_desc_m,
block_2_ctile_map, ygrad_grid_desc_m0_o_m1,
c0_matrix_mask, block_2_ctile_map,
p_drop, c0_matrix_mask,
ph, p_drop,
z_random_matrix_offset, ph,
raw_n_padded, z_random_matrix_offset,
0); raw_n_padded,
0);
} }
#else #else
ignore = p_a_grid; ignore = p_a_grid;
...@@ -1024,16 +1025,20 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1024,16 +1025,20 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
if(arg.p_drop_ > 0.0) if(arg.p_drop_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{}); ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
else else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{}); ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
} }
else else
{ {
if(arg.p_drop_ > 0.0) if(arg.p_drop_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{}); ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
else else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{}); ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
} }
return ave_time; return ave_time;
......
...@@ -999,16 +999,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -999,16 +999,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
if(arg.p_dropout_ > 0.0) if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{}); ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
else else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{}); ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
if(arg.p_dropout_ > 0.0) if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{}); ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
else else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{}); ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
} }
else else
{ {
......
...@@ -1006,16 +1006,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1006,16 +1006,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
if(arg.p_dropout_ > 0.0) if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{}); ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
else else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{}); ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
if(arg.p_dropout_ > 0.0) if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{}); ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
else else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{}); ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
} }
else else
{ {
......
...@@ -1935,8 +1935,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -1935,8 +1935,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
{ {
s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity(); s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
......
...@@ -1948,54 +1948,61 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1948,54 +1948,61 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
constexpr auto position_offset = M3 * M4; constexpr auto position_offset = M3 * M4;
// save z to global // save z to global
if constexpr(IsDropout){ if constexpr(IsDropout)
{
if(p_z_grid) if(p_z_grid)
{ {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin; auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; auto m_local =
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id n_global; // unique element global 1d id
auto global_elem_id = auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf), blockwise_dropout
decltype(z_tenor_buffer), .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(position_offset), decltype(z_tenor_buffer),
true>( decltype(position_offset),
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded); true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), z_thread_copy_vgpr_to_global.Run(
z_tenor_buffer, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_grid_buf); z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
} }
else else
{ {
ignore = z_grid_buf; ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin; auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; auto m_local =
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id n_global; // unique element global 1d id
auto global_elem_id = auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset), decltype(position_offset),
true>( true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded); s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
} }
} }
......
...@@ -1864,53 +1864,60 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1864,53 +1864,60 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
constexpr auto position_offset = M3 * M4; constexpr auto position_offset = M3 * M4;
// save z to global // save z to global
if constexpr(IsDropout){ if constexpr(IsDropout)
{
if(p_z_grid) if(p_z_grid)
{ {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin; auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; auto m_local =
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id n_global; // unique element global 1d id
auto global_elem_id = auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf), blockwise_dropout
decltype(z_tenor_buffer), .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(position_offset), decltype(z_tenor_buffer),
true>( decltype(position_offset),
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded); true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), z_thread_copy_vgpr_to_global.Run(
z_tenor_buffer, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_grid_buf); z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
} }
else else
{ {
ignore = z_grid_buf; ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin; auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; auto m_local =
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id n_global; // unique element global 1d id
auto global_elem_id = auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset), decltype(position_offset),
true>( true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded); s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
} }
} }
......
...@@ -165,7 +165,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -165,7 +165,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const index_t block_work_idx_m = block_work_idx[I0]; const index_t block_work_idx_m = block_work_idx[I0];
constexpr auto d_thread_desc_mblock_mrepeat_mwave_mperxdl = constexpr auto d_thread_desc_mblock_m1 =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1)); make_naive_tensor_descriptor_packed(make_tuple(I1, I1));
constexpr auto y_thread_desc_m0_m1_n0_n1 = make_naive_tensor_descriptor_packed(make_tuple( constexpr auto y_thread_desc_m0_m1_n0_n1 = make_naive_tensor_descriptor_packed(make_tuple(
...@@ -244,7 +244,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -244,7 +244,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
auto d_thread_copy_vgpr_to_global = auto d_thread_copy_vgpr_to_global =
ThreadwiseTensorSliceTransfer_v1r3<FloatD, ThreadwiseTensorSliceTransfer_v1r3<FloatD,
FloatD, FloatD,
decltype(d_thread_desc_mblock_mrepeat_mwave_mperxdl), decltype(d_thread_desc_mblock_m1),
decltype(d_grid_desc_mblock_mperblock), decltype(d_grid_desc_mblock_mperblock),
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Sequence<1, 1>, Sequence<1, 1>,
...@@ -260,7 +260,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -260,7 +260,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
ck::tensor_operation::element_wise::PassThrough{}}; ck::tensor_operation::element_wise::PassThrough{}};
// copy from VGPR to Global // copy from VGPR to Global
d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_mrepeat_mwave_mperxdl, d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_m1,
make_tuple(I0, I0), make_tuple(I0, I0),
y_dot_ygrad_thread_accum_buf, y_dot_ygrad_thread_accum_buf,
d_grid_desc_mblock_mperblock, d_grid_desc_mblock_mperblock,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment