Commit 38f48480 authored by danyao12's avatar danyao12
Browse files

code cleanup for bwd dropout

parent 4388b767
......@@ -122,7 +122,7 @@ struct BlockwiseDropout
});
}
template <typename CThreadBuffer, bool using_sign_bit = false>
template <typename CThreadBuffer, typename Offset, bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutAttnBwd(CThreadBuffer& in_thread_buf,
ck::philox& ph,
index_t element_global_1d_id,
......@@ -143,7 +143,7 @@ struct BlockwiseDropout
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw);
}
block_sync_lds();
......@@ -159,7 +159,10 @@ struct BlockwiseDropout
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
template <typename CThreadBuffer,
typename ZThreadBuffer,
typename Offset,
bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutAttnBwdSaveZ(CThreadBuffer& in_thread_buf,
ck::philox& ph,
index_t element_global_1d_id,
......@@ -181,7 +184,7 @@ struct BlockwiseDropout
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw);
}
block_sync_lds();
......
......@@ -135,9 +135,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto M5 = mfma.group_size;
return index_t(ceil(float(size) / M5) * M5);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
}
__device__ static auto GetGemm0WaveIdx()
......@@ -1613,21 +1613,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
n1, // NWaveId
m2, // MGroupNum
m3, // MInputNum
m4, // registerNum
m4, // RegisterNum
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
// z matrix global desc
// ignore = p_z_tmp_grid;
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( // tmp buffer for shuffle
p_z_grid,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
......@@ -1656,14 +1653,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // MBlockId
block_work_idx_n, // NBlockId
0, // mrepeat
0, // nrepeat
0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], // group
0, // NInputIndex
wave_m_n_id[I1]),
wave_m_n_id[I0], //
0, //
wave_m_n_id[I1]), // NPerXdl
tensor_operation::element_wise::PassThrough{}};
//
......@@ -1948,6 +1945,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global
if(p_z_grid)
{
......@@ -1962,10 +1960,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + int(global_elem_id_raw / M4) * M4;
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
......@@ -1989,12 +1988,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + int(global_elem_id_raw / M4) * M4;
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout
.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
block_sync_lds(); // wait for gemm1 LDS read
......
......@@ -149,9 +149,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto M5 = mfma.group_size;
return index_t(ceil(float(size) / M5) * M5);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
}
__device__ static auto GetGemm0WaveIdx()
......@@ -1574,21 +1574,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n1, // NWaveId
m2, // MGroupNum
m3, // MInputNum
m4, // registerNum
m4, // RegisterNum
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
// z matrix global desc
// ignore = p_z_tmp_grid;
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( // tmp buffer for shuffle
p_z_grid,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
......@@ -1617,14 +1614,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // MBlockId
block_work_idx_n, // NBlockId
0, // mrepeat
0, // nrepeat
0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], // group
0, // NInputIndex
wave_m_n_id[I1]),
wave_m_n_id[I0], //
0, //
wave_m_n_id[I1]), // NPerXdl
tensor_operation::element_wise::PassThrough{}};
//
......@@ -1864,6 +1861,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global
if(p_z_grid)
{
......@@ -1878,10 +1876,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + int(global_elem_id_raw / M4) * M4;
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
......@@ -1905,11 +1904,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + int(global_elem_id_raw / M4) * M4;
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout
.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
block_sync_lds(); // wait for gemm1 LDS read
......
......@@ -1989,43 +1989,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
//}
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// if(get_thread_global_1d_id() == 0){
// printf("Acc0TileIterator::GetNumOfAccess() is %d \n",
// Acc0TileIterator::GetNumOfAccess()); printf("n0.value is %d \n", n0.value);
// printf("id_step is %d \n", id_step);
//}
// dropout
// z_tenor_buffer_tmp -> z_grid_buf_tmp -> shuffle -> z_tenor_buffer -> z_grid_buf
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
......@@ -2036,29 +2011,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
//// P_dropped
// static_for<0, n0, 1>{}([&](auto i) {
// blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
// decltype(z_tenor_buffer),
// true,
// decltype(n0),
// decltype(i)>(
// s_slash_p_thread_buf, ph, z_tenor_buffer);
//
// z_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
}
else
{
......@@ -2094,22 +2046,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
......
......@@ -1925,43 +1925,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
//}
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// if(get_thread_global_1d_id() == 0){
// printf("Acc0TileIterator::GetNumOfAccess() is %d \n",
// Acc0TileIterator::GetNumOfAccess()); printf("n0.value is %d \n", n0.value);
// printf("id_step is %d \n", id_step);
//}
// dropout
// z_tenor_buffer_tmp -> z_grid_buf_tmp -> shuffle -> z_tenor_buffer -> z_grid_buf
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
......@@ -2007,22 +1982,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment