Commit af1059a3 authored by guangzlu's avatar guangzlu
Browse files

fixed bugs for lds shuffle

parent 957ab734
......@@ -255,7 +255,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
const index_t z_block_bytes_end =
SharedMemTrait::z_shuffle_block_space_size * sizeof(ZDataType);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
softmax_bytes_end,
c_block_bytes_end,
z_block_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......@@ -415,6 +422,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
static constexpr auto z_shuffle_block_space_size = MPerBlock * NPerBlock;
};
template <bool HasMainKBlockLoop,
......@@ -874,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3, // NInputNum
n4)); // registerNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 = // for blockwise copy
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, //
n0, //
m1, //
......@@ -911,7 +920,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr auto zN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto zN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 =
constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 =
transform_tensor_descriptor(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_pass_through_transform(zM0),
......@@ -940,21 +949,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence<8>{}));
// ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4;
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tenor_tmp_buffer;
z_tenor_tmp_buffer.Clear();
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4;
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4.GetElementSpaceSize(),
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize(),
true>
z_tenor_buffer; // z buffer after shuffle
z_tenor_buffer.Clear();
z_tensor_buffer; // z buffer after shuffle
z_tensor_buffer.Clear();
// z matrix global desc
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -964,6 +966,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
static_cast<ZDataType*>(p_shared),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
// if(get_thread_global_1d_id()==0){
// printf("z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize() is %ld \n",
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize().value);
// printf("z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n", z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// printf("z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld
// \n",z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize().value);
// printf("z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n",z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// }
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
......@@ -1028,14 +1041,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType,
ushort,
decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4),
decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
Sequence<m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
8,
1,
1,
true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4,
true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_multi_index(0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
......@@ -1222,54 +1235,42 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global
if(p_z_grid)
{
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tenor_tmp_buffer)>(
ph, global_elem_id, z_tenor_tmp_buffer);
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer)>(
ph, global_elem_id, z_tensor_buffer);
z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_tmp_buffer,
z_tensor_buffer,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_block_buf);
block_sync_lds();
// ignore = z_shuffle_thread_copy_lds_to_vgpr;
z_shuffle_thread_copy_lds_to_vgpr.Run(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4,
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
z_block_buf,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer);
z_tensor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
decltype(z_tensor_buffer),
false>(acc_thread_buf,
z_tenor_buffer);
z_tensor_buffer);
// ignore = z_tenor_buffer;
if(p_z_grid)
{
// ignore = z_tensor_buffer;
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
block_sync_lds();
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
else
{
// ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnFwd<decltype(acc_thread_buf), false>(
acc_thread_buf, ph, global_elem_id);
}
}
// TODO: may convert to log domain
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment