Commit 957ab734 authored by guangzlu's avatar guangzlu
Browse files

removed global shuffle parameters

parent a3c14b5f
......@@ -143,87 +143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
//// Z desc for source in blockwise copy
//__host__ __device__ static constexpr auto GetZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() ////=>
//for z use
//{
// //const auto M = z_grid_desc_m_n.GetLength(I0);
// //const auto N = z_grid_desc_m_n.GetLength(I1);
//
// constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
// constexpr auto N3 = mfma.num_groups_per_blk;
// constexpr auto N4 = mfma.num_input_blks;
// constexpr auto N5 = mfma.group_size;
// return make_naive_tensor_descriptor_packed(
// make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, Number<Gemm0MWaves>{},
// Number<Gemm0NWaves>{},
// Number<MPerXdl>{}, Number<N3>{}, Number<N4>{}, Number<N5>{}));
//}
// C shuffle desc for source in gridwise copy
// Builds the 11-dimensional "shuffle" view of the 2-D Z grid descriptor.
//
// The M axis (lower dimension 0) is unmerged into five lengths:
//   (M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl / group_size, group_size)
// and the N axis (lower dimension 1) into six lengths:
//   (N / NPerBlock, NXdlPerWave, Gemm0NWaves, num_groups_per_blk, num_input_blks, group_size)
// where group_size, num_groups_per_blk and num_input_blks are read from the
// MFMA instruction selected for <FloatGemm, MPerXdl, NPerXdl>.
//
// The M pieces land on upper dimensions {0, 2, 4, 6, 9} and the N pieces on
// upper dimensions {1, 3, 5, 7, 8, 10}.
//
// NOTE(review): read positionally, the function name suggests M4 at index 8
// and N4 at index 9, whereas the mapping below places an N piece at index 8
// and an M piece at index 9 -- confirm this is the intended shuffle order.
// NOTE(review): M / MPerBlock and N / NPerBlock are plain divisions; this
// presumably relies on M and N being multiples of the block tile -- verify
// against the caller.
__host__ __device__ static constexpr auto
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5(
    const ZGridDesc_M_N& z_grid_desc_m_n)
{
    // Per-instruction layout constants of the selected MFMA.
    constexpr auto xdl_instr      = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
    constexpr auto groups_per_blk = xdl_instr.num_groups_per_blk;
    constexpr auto input_blks     = xdl_instr.num_input_blks;
    constexpr auto group_len      = xdl_instr.group_size;

    // Extents of the 2-D source descriptor.
    const auto m_len = z_grid_desc_m_n.GetLength(I0);
    const auto n_len = z_grid_desc_m_n.GetLength(I1);

    return transform_tensor_descriptor(
        z_grid_desc_m_n,
        make_tuple(
            // M -> (block, xdl-repeat, wave, MPerXdl / group_len, group_len)
            make_unmerge_transform(make_tuple(m_len / MPerBlock,
                                              MXdlPerWave,
                                              Gemm0MWaves,
                                              MPerXdl / group_len,
                                              group_len)),
            // N -> (block, xdl-repeat, wave, groups, input blocks, group_len)
            make_unmerge_transform(make_tuple(n_len / NPerBlock,
                                              NXdlPerWave,
                                              Gemm0NWaves,
                                              groups_per_blk,
                                              input_blks,
                                              group_len))),
        // lower dimension ids consumed by each transform
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        // upper dimension ids produced by each transform
        make_tuple(Sequence<0, 2, 4, 6, 9>{}, Sequence<1, 3, 5, 7, 8, 10>{}));
}
// using ZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4 = remove_cvref_t<decltype(
// GetZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Z shuffle desc for source in blockwise copy
//__host__ __device__ static constexpr auto
// GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4(const
// ZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4& z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4) ////=> for z
// use to shuffle
//{
// //const auto M = z_grid_desc_m_n.GetLength(I0);
// //const auto N = z_grid_desc_m_n.GetLength(I1);
//
// constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
// constexpr auto N3 = mfma.num_groups_per_blk;
// constexpr auto N4 = mfma.num_input_blks;
// constexpr auto N5 = mfma.group_size;
//
// constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 =
// transform_tensor_descriptor(
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// make_tuple(
// make_freeze_transform(Number<MXdlPerWave>{}),
// make_freeze_transform(Number<NXdlPerWave>{}),
// make_freeze_transform(Number<Gemm0MWaves>{}),
// make_freeze_transform(Number<Gemm0NWaves>{}),
// make_unmerge_transform(make_tuple(Number<MPerXdl / N5>{}, Number<N5>{})),
// make_freeze_transform(Number<N3>{}),
// make_freeze_transform(Number<N4>{}),
// make_freeze_transform(Number<N5>{})),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},Sequence<4>{},
// Sequence<5>{}, Sequence<6>{}, Sequence<7>{}), make_tuple(Sequence<0>{}, Sequence<1>{},
// Sequence<2>{}, Sequence<3>{},Sequence<4,7>{}, Sequence<5>{}, Sequence<6>{},
// Sequence<8>{}));
// return z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4;
// //return make_naive_tensor_descriptor_packed(
// // make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, Number<Gemm0MWaves>{},
// Number<Gemm0NWaves>{},
// // Number<MPerXdl / N5>{}, Number<N3>{}, Number<N4>{}, Number<N5>{},
// Number<N5>{}));
//}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
......@@ -462,9 +381,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// Descriptor type produced by MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
// when applied to a default-constructed ZGridDesc_M_N (cv/ref qualifiers stripped).
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
// Descriptor type produced by MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
// when applied to a default-constructed ZGridDesc_M_N (cv/ref qualifiers stripped).
using ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5 = remove_cvref_t<decltype(
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5(ZGridDesc_M_N{}))>;
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
......@@ -525,8 +441,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5&
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
......@@ -971,20 +885,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n4, // m1 4
I1)); // n2 1
// constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5 = //for gridwise
// copy
// make_naive_tensor_descriptor_packed(make_tuple(I1,
// I1,
// m0, //
// n0, //
// m1, //
// n1, //
// m2, // m0 1
// n2, // n0 4
// n3, // n1 1
// n4, // m1 4
// I1)); // n2 1
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
......@@ -1060,26 +960,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// auto z_grid_tmp_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_z_grid,
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize());
ignore = z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ZDataType*>(p_shared),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
// ignore = z_block_buf;
// if(get_thread_global_1d_id()==0){
// printf("z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld \n",
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// printf("z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize() is
// %ld \n", z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize());
//
//}
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
......@@ -1109,59 +993,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
//}
/*
auto z_tmp_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType,
ushort,
decltype(z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
Sequence<I1, I1, m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10>,
10,
1,
1,
true >{z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_multi_index(block_work_idx_m, //
MBlockId 0, // NBlockId 0, // mrepeat 0, //
nrepeat wave_id[I0], // MWaveId wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), //
MPerXdl 0, // group wave_m_n_id[I0], // NInputIndex 0,
wave_m_n_id[I1] % 4)};
*/
auto z_tmp_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
......@@ -1402,13 +1233,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_block_buf);
// z_tmp_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_tmp_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_tmp_buf);
block_sync_lds();
// ignore = z_shuffle_thread_copy_lds_to_vgpr;
......@@ -1420,17 +1244,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer);
// z_shuffle_thread_copy_global_to_vgpr.Run(
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// z_grid_tmp_buf,
// z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
false>(acc_thread_buf,
z_tenor_buffer);
// ignore = z_tenor_buffer;
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......@@ -1441,14 +1259,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
block_sync_lds();
// z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
// z_shuffle_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment