Commit c07c2b55 authored by danyao12
Browse files

continue fwd cleanup

parent dad06b35
......@@ -882,19 +882,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n4)); // registerNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, //
n0, //
m1, //
n1, //
m2, // m0 1
n2, // n0 4
n3, // n1 1
n4, // m1 4
I1)); // n2 1
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4, // registerNum
I1)); // I1
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
I1, // NBlockId
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
......@@ -907,26 +907,26 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto zM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto zN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto zM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto zN1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto zM2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto zN2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto zN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto zN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto ZM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto ZN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto ZM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto ZN1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto ZM2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto ZN2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto ZN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto ZN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 =
transform_tensor_descriptor(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_pass_through_transform(zM0),
make_pass_through_transform(zN0),
make_pass_through_transform(zM1),
make_pass_through_transform(zN1),
make_unmerge_transform(make_tuple(Number<zM2.value / zN4.value>{}, zN4)),
make_pass_through_transform(zN2),
make_pass_through_transform(zN3),
make_pass_through_transform(zN4)),
make_tuple(make_pass_through_transform(ZM0),
make_pass_through_transform(ZN0),
make_pass_through_transform(ZM1),
make_pass_through_transform(ZN1),
make_unmerge_transform(make_tuple(ZM2 / ZN4, ZN4)),
make_pass_through_transform(ZN2),
make_pass_through_transform(ZN3),
make_pass_through_transform(ZN4)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -955,7 +955,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ZDataType*>(p_shared),
static_cast<ushort*>(p_shared),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
......@@ -967,14 +967,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>, // registerNum
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
......@@ -1003,15 +1003,15 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
1,
1,
true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_multi_index(0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
make_multi_index(0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1] / ZN4,
0,
wave_m_n_id[I1] % 4)};
wave_m_n_id[I0],
0,
wave_m_n_id[I1] % ZN4)};
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
......@@ -1091,8 +1091,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(zM0, zM1, zM2)),
make_unmerge_transform(make_tuple(zN0, zN1, zN2, zN3, zN4))),
make_tuple(make_unmerge_transform(make_tuple(ZM0, ZM1, ZM2)),
make_unmerge_transform(make_tuple(ZN0, ZN1, ZN2, ZN3, ZN4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment