Commit af1059a3 authored by guangzlu
Browse files

fixed bugs for lds shuffle

parent 957ab734
...@@ -255,7 +255,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -255,7 +255,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const index_t c_block_bytes_end = const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); const index_t z_block_bytes_end =
SharedMemTrait::z_shuffle_block_space_size * sizeof(ZDataType);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
softmax_bytes_end,
c_block_bytes_end,
z_block_bytes_end);
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...@@ -415,6 +422,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -415,6 +422,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size = static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
static constexpr auto z_shuffle_block_space_size = MPerBlock * NPerBlock;
}; };
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
...@@ -874,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -874,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3, // NInputNum n3, // NInputNum
n4)); // registerNum n4)); // registerNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 = // for blockwise copy constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // make_naive_tensor_descriptor_packed(make_tuple(m0, //
n0, // n0, //
m1, // m1, //
...@@ -911,7 +920,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -911,7 +920,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr auto zN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); constexpr auto zN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto zN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); constexpr auto zN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 = constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 =
transform_tensor_descriptor( transform_tensor_descriptor(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_pass_through_transform(zM0), make_tuple(make_pass_through_transform(zM0),
...@@ -940,21 +949,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -940,21 +949,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence<8>{})); Sequence<8>{}));
// ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4; // ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4; // ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4;
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tenor_tmp_buffer;
z_tenor_tmp_buffer.Clear();
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short, unsigned short,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4.GetElementSpaceSize(), z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize(),
true> true>
z_tenor_buffer; // z buffer after shuffle z_tensor_buffer; // z buffer after shuffle
z_tenor_buffer.Clear(); z_tensor_buffer.Clear();
// z matrix global desc // z matrix global desc
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -964,6 +966,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -964,6 +966,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
static_cast<ZDataType*>(p_shared), static_cast<ZDataType*>(p_shared),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize()); z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
// if(get_thread_global_1d_id()==0){
// printf("z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize() is %ld \n",
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize().value);
// printf("z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n", z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// printf("z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld
// \n",z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize().value);
// printf("z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n",z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// }
const auto wave_id = GetGemm0WaveIdx(); const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
...@@ -1028,14 +1041,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1028,14 +1041,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType, ZDataType,
ushort, ushort,
decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4), decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4), decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
Sequence<m0, n0, m1, n1, m2, n2, n3, n4, I1>, Sequence<m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
8, 8,
1, 1,
1, 1,
true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4, true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_multi_index(0, // mrepeat make_multi_index(0, // mrepeat
0, // nrepeat 0, // nrepeat
wave_id[I0], // MWaveId wave_id[I0], // MWaveId
...@@ -1222,54 +1235,42 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1222,54 +1235,42 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value; // index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global // save z to global
if(p_z_grid)
{
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tenor_tmp_buffer)>(
ph, global_elem_id, z_tenor_tmp_buffer);
z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_tmp_buffer,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_block_buf);
block_sync_lds(); blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer)>(
ph, global_elem_id, z_tensor_buffer);
// ignore = z_shuffle_thread_copy_lds_to_vgpr; z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_block_buf);
z_shuffle_thread_copy_lds_to_vgpr.Run( z_shuffle_thread_copy_lds_to_vgpr.Run(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4, z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
z_block_buf, z_block_buf,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4, z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer); z_tensor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf), blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
decltype(z_tenor_buffer), decltype(z_tensor_buffer),
false>(acc_thread_buf, false>(acc_thread_buf,
z_tenor_buffer); z_tensor_buffer);
// ignore = z_tenor_buffer; if(p_z_grid)
{
// ignore = z_tensor_buffer;
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf); z_grid_buf);
block_sync_lds();
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} }
else
{
// ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnFwd<decltype(acc_thread_buf), false>(
acc_thread_buf, ph, global_elem_id);
}
} }
// TODO: may convert to log domain // TODO: may convert to log domain
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment