Commit 9e16e38e authored by guangzlu

Added dropout shuffle in LDS for the forward pass.

parent cd6e9903
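In outline, this commit changes how the forward attention kernel shuffles its Z dropout mask between threads: instead of writing the per-thread 16-bit Philox values to global memory and reading them back through a shuffled grid descriptor (z_grid_tmp_buf and z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5), the tile is now staged in LDS (z_block_buf over z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4) and read back through the new z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 before the thresholded dropout is applied. The standalone sketch below only illustrates that flow under simplifying assumptions: plain arrays stand in for registers and LDS, a transpose stands in for the descriptor-based shuffle, and all names and sizes in it are made up rather than taken from the kernel's API.

// Hypothetical, self-contained illustration (plain C++, not the CK kernel API):
// stage a 16-bit random "Z" tile through a scratch buffer in a shuffled order,
// then keep/drop the way the execute_dropout lambda in the first hunk below does
// (keep iff z <= p_dropout_16bits, scale kept values by p_dropout_rescale).
#include <array>
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int MRepeat = 2, KRepeat = 4, N = MRepeat * KRepeat; // made-up sizes

    // Stand-in for the per-thread Philox output (z_thread_buf).
    std::array<uint16_t, N> z_thread = {1200, 64000, 30000, 100, 52000, 7, 41000, 22222};

    // Stand-in for the LDS tile: write linearly ...
    std::array<uint16_t, N> lds{};
    for(int i = 0; i < N; ++i)
        lds[i] = z_thread[i];

    // ... and read back through a different index map (a transpose here, standing in
    // for the unmerge-based z_block_shuffle descriptor in the real kernel).
    std::array<uint16_t, N> z_shuffled{};
    for(int m = 0; m < MRepeat; ++m)
        for(int k = 0; k < KRepeat; ++k)
            z_shuffled[k * MRepeat + m] = lds[m * KRepeat + k];

    // Thresholded dropout on the shuffled values; the threshold and rescale
    // encodings below are assumptions, not taken from the kernel.
    const float p_dropout           = 0.2f;
    const uint16_t p_dropout_16bits = static_cast<uint16_t>((1.0f - p_dropout) * 65535.0f);
    const float p_dropout_rescale   = 1.0f / (1.0f - p_dropout);

    for(int i = 0; i < N; ++i)
    {
        const bool keep = z_shuffled[i] <= p_dropout_16bits;
        const float val = keep ? 1.0f * p_dropout_rescale : 0.0f;
        std::printf("elem %d: z=%u keep=%d val=%.2f\n", i, unsigned(z_shuffled[i]), int(keep), val);
    }
    return 0;
}

Keeping the staging buffer in LDS keeps the shuffle on-chip; the global round trip through z_grid_tmp_buf that the old path (now commented out in the diff below) relied on is no longer needed.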
@@ -255,11 +255,16 @@ struct BlockwiseDropout
return keep ? val * p_dropout_rescale : float(0);
};
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
tmp_index = tmp_index + 1;
// if(get_thread_global_1d_id()==0){
// printf("z at %d is %u \n", tmp_index, z_thread_buf(offset));
//}
});
});
}
@@ -306,48 +311,6 @@ struct BlockwiseDropout
});
}
template <typename ZThreadBuffer>
__host__ __device__ void GenerateZMatrix(ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf,
index_t MRaw)
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
}
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
// }
//}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
ushort p_dropout_16bits;
DataType p_dropout_rescale;
};
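For reference, the GenerateZMatrix helper removed above fills its MRepeat * KRepeat buffer of 16-bit values four at a time: one ph.get_random_4x16 call per group of four, each call offset by i * 8 * MRaw. The standalone sketch below reproduces only that fill pattern; the RNG is a plain counter standing in for ck::philox, and the sizes are made up.

#include <cstdint>
#include <cstdio>

// Stand-in for ck::philox: fills four 16-bit values per call (not actually random).
struct StubRng
{
    void get_random_4x16(uint16_t* out, uint32_t offset)
    {
        for(int j = 0; j < 4; ++j)
            out[j] = static_cast<uint16_t>(offset + j);
    }
};

int main()
{
    constexpr int MRepeat  = 2, KRepeat = 4;    // made-up repeat counts
    constexpr int tmp_size = MRepeat * KRepeat; // one Z value per (iM, iK) element
    const int philox_calls = tmp_size / 4;      // four 16-bit values per RNG call

    const uint32_t element_global_1d_id = 0;    // assumed starting offset
    const uint32_t MRaw                 = 128;  // assumed problem size

    uint16_t tmp[tmp_size];
    StubRng ph;
    for(int i = 0; i < philox_calls; i++)
        ph.get_random_4x16(tmp + i * 4, element_global_1d_id + i * 8 * MRaw);

    for(int i = 0; i < tmp_size; i++)
        std::printf("tmp[%d] = %u\n", i, unsigned(tmp[i]));
    return 0;
}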
......
@@ -143,6 +143,23 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
//// Z desc for source in blockwise copy
//__host__ __device__ static constexpr auto GetZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() ////=>
//for z use
//{
// //const auto M = z_grid_desc_m_n.GetLength(I0);
// //const auto N = z_grid_desc_m_n.GetLength(I1);
//
// constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
// constexpr auto N3 = mfma.num_groups_per_blk;
// constexpr auto N4 = mfma.num_input_blks;
// constexpr auto N5 = mfma.group_size;
// return make_naive_tensor_descriptor_packed(
// make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, Number<Gemm0MWaves>{},
// Number<Gemm0NWaves>{},
// Number<MPerXdl>{}, Number<N3>{}, Number<N4>{}, Number<N5>{}));
//}
// C shuffle desc for source in gridwise copy
__host__ __device__ static constexpr auto
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5(
@@ -156,18 +173,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
// printf("M / MPerBlock %d, ", M / MPerBlock);
// printf("MXdlPerWave %d, " , MXdlPerWave);
// printf("Gemm0MWaves %d, " , Gemm0MWaves);
// printf("MPerXdl / N5 %d, " , MPerXdl / N5);
// printf("N5 %d \n" , N5);
// printf("N / NPerBlock %d, " , N / NPerBlock);
// printf("NXdlPerWave %d, " , NXdlPerWave);
// printf("Gemm0NWaves %d, " , Gemm0NWaves);
// printf("N3 %d, " , N3);
// printf("N4 %d, " , N4);
// printf("N5 %d, " , N5);
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
@@ -175,9 +180,50 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6, 9>{}, Sequence<1, 3, 5, 7, 8, 10>{})); // 0247,13568
}
// using ZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4 = remove_cvref_t<decltype(
// GetZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Z shuffle desc for source in blockwise copy
//__host__ __device__ static constexpr auto
// GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4(const
// ZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4& z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4) ////=> for z
// use to shuffle
//{
// //const auto M = z_grid_desc_m_n.GetLength(I0);
// //const auto N = z_grid_desc_m_n.GetLength(I1);
//
// constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
// constexpr auto N3 = mfma.num_groups_per_blk;
// constexpr auto N4 = mfma.num_input_blks;
// constexpr auto N5 = mfma.group_size;
//
// constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 =
// transform_tensor_descriptor(
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// make_tuple(
// make_freeze_transform(Number<MXdlPerWave>{}),
// make_freeze_transform(Number<NXdlPerWave>{}),
// make_freeze_transform(Number<Gemm0MWaves>{}),
// make_freeze_transform(Number<Gemm0NWaves>{}),
// make_unmerge_transform(make_tuple(Number<MPerXdl / N5>{}, Number<N5>{})),
// make_freeze_transform(Number<N3>{}),
// make_freeze_transform(Number<N4>{}),
// make_freeze_transform(Number<N5>{})),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},Sequence<4>{},
// Sequence<5>{}, Sequence<6>{}, Sequence<7>{}), make_tuple(Sequence<0>{}, Sequence<1>{},
// Sequence<2>{}, Sequence<3>{},Sequence<4,7>{}, Sequence<5>{}, Sequence<6>{},
// Sequence<8>{}));
// return z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4;
// //return make_naive_tensor_descriptor_packed(
// // make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, Number<Gemm0MWaves>{},
// Number<Gemm0NWaves>{},
// // Number<MPerXdl / N5>{}, Number<N3>{}, Number<N4>{}, Number<N5>{},
// Number<N5>{}));
//}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
@@ -904,6 +950,41 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// printf("n4 is %d \n",n4.value);
//}
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, //
n0, //
m1, //
n1, //
m2, // m0 1
n2, // n0 4
n3, // n1 1
n4, // m1 4
I1)); // n2 1
// constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5 = //for gridwise
// copy
// make_naive_tensor_descriptor_packed(make_tuple(I1,
// I1,
// m0, //
// n0, //
// m1, //
// n1, //
// m2, // m0 1
// n2, // n0 4
// n3, // n1 1
// n4, // m1 4
// I1)); // n2 1
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
@@ -916,19 +997,50 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3, // NInputNum
n4)); // registerNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, //
I1, //
m0, //
n0, //
m1, //
n1, //
m2, // m0
n2, // m1
n3, // n0
n4, // n1
I1)); // n2
// ignore = z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;

constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

// constexpr auto z_block_lengths = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths();

constexpr auto zM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto zN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto zM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto zN1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto zM2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto zN2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto zN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto zN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 =
transform_tensor_descriptor(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_pass_through_transform(zM0),
make_pass_through_transform(zN0),
make_pass_through_transform(zM1),
make_pass_through_transform(zN1),
make_unmerge_transform(make_tuple(Number<zM2.value / zN4.value>{}, zN4)),
make_pass_through_transform(zN2),
make_pass_through_transform(zN3),
make_pass_through_transform(zN4)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 7>{},
Sequence<5>{},
Sequence<6>{},
Sequence<8>{}));
// ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4;
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
@@ -939,7 +1051,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4.GetElementSpaceSize(),
true>
z_tenor_buffer; // z buffer after shuffle
z_tenor_buffer.Clear();
@@ -948,10 +1060,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());

auto z_grid_tmp_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// auto z_grid_tmp_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_z_grid,
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize());
ignore = z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ZDataType*>(p_shared),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
// ignore = z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
// ignore = z_block_buf;

// if(get_thread_global_1d_id()==0){
// printf("z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld \n",
@@ -990,63 +1109,111 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
//}
/*
auto z_tmp_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType,
ushort,
decltype(z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
Sequence<I1, I1, m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10>,
10,
1,
1,
true >{z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_multi_index(block_work_idx_m, //
MBlockId 0, // NBlockId 0, // mrepeat 0, //
nrepeat wave_id[I0], // MWaveId wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), //
MPerXdl 0, // group wave_m_n_id[I0], // NInputIndex 0,
wave_m_n_id[I1] % 4)};
*/
auto z_tmp_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType,
ushort,
decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4),
Sequence<m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
8,
1,
1,
true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4,
make_multi_index(0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0,
wave_m_n_id[I1] % 4)};

ZDataType,
ushort,
decltype(z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
Sequence<I1, I1, m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10>,
10,
1,
1,
true /* ResetCoordAfterRun */>{z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0,
wave_m_n_id[I1] % 4)};
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
@@ -1229,27 +1396,42 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tenor_tmp_buffer)>(
ph, global_elem_id, z_tenor_tmp_buffer);
z_tmp_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_tmp_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_tmp_buf);

z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_tmp_buffer,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_block_buf);
// z_tmp_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_tmp_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_tmp_buf);
block_sync_lds();

z_shuffle_thread_copy_global_to_vgpr.Run(
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
z_grid_tmp_buf,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer);

// ignore = z_shuffle_thread_copy_lds_to_vgpr;

z_shuffle_thread_copy_lds_to_vgpr.Run(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4,
z_block_buf,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer);
// z_shuffle_thread_copy_global_to_vgpr.Run(
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// z_grid_tmp_buf,
// z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
false>(acc_thread_buf,
z_tenor_buffer);
// ignore = z_tenor_buffer;
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
@@ -1259,13 +1441,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
block_sync_lds();
z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
z_shuffle_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0));

// z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
// z_shuffle_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0));

z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......
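As a closing note on the index math: the new z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 above differs from the plain block descriptor by a single unmerge transform, which splits the zM2 length into (zM2 / zN4, zN4) and routes the two pieces to separate output dimensions (the Sequence<4, 7>{} upper mapping). A minimal sketch of that split, assuming the usual row-major unmerge convention (inner index fastest) and hypothetical sizes in place of the lengths taken from the block descriptor:

#include <cstdio>

int main()
{
    constexpr int zM2 = 32; // stand-in for z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4)
    constexpr int zN4 = 4;  // stand-in for z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7)

    // Unmerge of a flat m2 index into an outer and an inner coordinate,
    // inner varying fastest: m2 = outer * zN4 + inner.
    for(int m2 = 0; m2 < zM2; ++m2)
    {
        const int outer = m2 / zN4;
        const int inner = m2 % zN4;
        std::printf("m2=%2d -> (outer=%d, inner=%d)\n", m2, outer, inner);
    }
    return 0;
}

Reading the LDS tile back through this shuffled view, rather than through the old z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5 global descriptor, is what the new z_shuffle_thread_copy_lds_to_vgpr transfer relies on.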