Commit dad06b35 authored by danyao12's avatar danyao12
Browse files

code cleanup for fwd dropout

parent 51e102e5
...@@ -122,8 +122,7 @@ struct BlockwiseDropout ...@@ -122,8 +122,7 @@ struct BlockwiseDropout
}); });
} }
template <typename CThreadBuffer, bool using_sign_bit = false>
template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutAttnBwd(CThreadBuffer& in_thread_buf, __host__ __device__ void ApplyDropoutAttnBwd(CThreadBuffer& in_thread_buf,
ck::philox& ph, ck::philox& ph,
index_t element_global_1d_id, index_t element_global_1d_id,
...@@ -185,15 +184,6 @@ struct BlockwiseDropout ...@@ -185,15 +184,6 @@ struct BlockwiseDropout
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
} }
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
// }
//}
block_sync_lds(); block_sync_lds();
int tmp_index = 0; int tmp_index = 0;
...@@ -227,9 +217,6 @@ struct BlockwiseDropout ...@@ -227,9 +217,6 @@ struct BlockwiseDropout
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits, in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset)); in_thread_buf(offset));
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
// if(get_thread_global_1d_id()==0){
// printf("z at %d is %u \n", tmp_index, z_thread_buf(offset));
//}
}); });
}); });
} }
...@@ -240,11 +227,6 @@ struct BlockwiseDropout ...@@ -240,11 +227,6 @@ struct BlockwiseDropout
index_t element_global_1d_id, index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf) ZThreadBuffer& z_thread_buf)
{ {
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 4;
...@@ -255,15 +237,6 @@ struct BlockwiseDropout ...@@ -255,15 +237,6 @@ struct BlockwiseDropout
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
} }
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
// }
//}
block_sync_lds(); block_sync_lds();
int tmp_index = 0; int tmp_index = 0;
......
...@@ -145,9 +145,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -145,9 +145,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size) __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{ {
constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma; constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N5 = mfma.group_size; constexpr auto group_size = mfma.group_size;
return index_t(ceil(float(size) / N5) * N5); return math::integer_divide_ceil(size, group_size) * group_size;
} }
__device__ static auto GetGemm0WaveIdx() __device__ static auto GetGemm0WaveIdx()
...@@ -263,7 +263,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -263,7 +263,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
const index_t z_block_bytes_end = const index_t z_block_bytes_end =
SharedMemTrait::z_shuffle_block_space_size * sizeof(ZDataType); SharedMemTrait::z_shuffle_block_space_size * sizeof(ushort);
return math::max(gemm0_bytes_end, return math::max(gemm0_bytes_end,
gemm1_bytes_end, gemm1_bytes_end,
...@@ -871,14 +871,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -871,14 +871,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// z vgpr copy to global // z vgpr copy to global
// //
// z matrix threadwise desc // z matrix threadwise desc
// if(get_thread_global_1d_id()==0){
// printf("m2 is %d \n",m2.value);
// printf("n2 is %d \n",n2.value);
// printf("n3 is %d \n",n3.value);
// printf("n4 is %d \n",n4.value);
//}
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
n0, // NRepeat n0, // NRepeat
...@@ -915,8 +907,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -915,8 +907,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// constexpr auto z_block_lengths = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths();
constexpr auto zM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); constexpr auto zM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto zN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); constexpr auto zN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto zM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); constexpr auto zM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
...@@ -954,17 +944,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -954,17 +944,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence<6>{}, Sequence<6>{},
Sequence<8>{})); Sequence<8>{}));
// ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4;
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short, ushort,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize(), z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize(),
true> true>
z_tensor_buffer; // z buffer after shuffle z_tensor_buffer;
z_tensor_buffer.Clear(); z_tensor_buffer.Clear();
// z matrix global desc
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize()); p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
...@@ -972,50 +958,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -972,50 +958,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
static_cast<ZDataType*>(p_shared), static_cast<ZDataType*>(p_shared),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize()); z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
// if(get_thread_global_1d_id()==0){
// printf("z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize() is %ld \n",
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize().value);
// printf("z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n", z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// printf("z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld
// \n",z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize().value);
// printf("z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n",z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// }
const auto wave_id = GetGemm0WaveIdx(); const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
// if(get_block_1d_id()==0){
// if(get_thread_local_1d_id()==0){
// printf("tid = 0 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==1){
// printf("tid = 1 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==2){
// printf("tid = 2 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==3){
// printf("tid = 3 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==32){
// printf("tid = 32 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==64){
// printf("tid = 32 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
//}
auto z_tmp_thread_copy_vgpr_to_lds = auto z_tmp_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<ushort, ThreadwiseTensorSliceTransfer_v1r3<ushort,
ZDataType, ushort,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1045,7 +993,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1045,7 +993,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType, ushort,
ushort, ushort,
decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4), decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4), decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
...@@ -1111,10 +1059,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1111,10 +1059,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if(c0_matrix_mask.IsTileSkippable( if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock)) m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{ {
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
continue; continue;
} }
// gemm0 // gemm0
...@@ -1199,10 +1143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1199,10 +1143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto global_elem_id = z_random_matrix_offset + m_global * raw_n_padded + auto global_elem_id = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id n_global; // unique element global 1d id
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer)>( blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer)>(
ph, global_elem_id, z_tensor_buffer); ph, global_elem_id, z_tensor_buffer);
...@@ -1224,6 +1164,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1224,6 +1164,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false>(acc_thread_buf, false>(acc_thread_buf,
z_tensor_buffer); z_tensor_buffer);
// save z to global
if(p_z_grid) if(p_z_grid)
{ {
// ignore = z_tensor_buffer; // ignore = z_tensor_buffer;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment