Commit c07c2b55 authored by danyao12
Browse files

continue fwd cleanup

parent dad06b35
...@@ -882,19 +882,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -882,19 +882,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n4)); // registerNum n4)); // registerNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
n0, // n0, // NRepeat
m1, // m1, // MWaveId
n1, // n1, // NWaveId
m2, // m0 1 m2, // MPerXdl
n2, // n0 4 n2, // NGroupNum
n3, // n1 1 n3, // NInputNum
n4, // m1 4 n4, // registerNum
I1)); // n2 1 I1)); // I1
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockId
m0, // MRepeat m0, // MRepeat
n0, // NRepeat n0, // NRepeat
m1, // MWaveId m1, // MWaveId
...@@ -907,26 +907,26 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -907,26 +907,26 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto zM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); constexpr auto ZM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto zN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); constexpr auto ZN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto zM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); constexpr auto ZM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto zN1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); constexpr auto ZN1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto zM2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); constexpr auto ZM2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto zN2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); constexpr auto ZN2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto zN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); constexpr auto ZN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto zN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); constexpr auto ZN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 =
transform_tensor_descriptor( transform_tensor_descriptor(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_pass_through_transform(zM0), make_tuple(make_pass_through_transform(ZM0),
make_pass_through_transform(zN0), make_pass_through_transform(ZN0),
make_pass_through_transform(zM1), make_pass_through_transform(ZM1),
make_pass_through_transform(zN1), make_pass_through_transform(ZN1),
make_unmerge_transform(make_tuple(Number<zM2.value / zN4.value>{}, zN4)), make_unmerge_transform(make_tuple(ZM2 / ZN4, ZN4)),
make_pass_through_transform(zN2), make_pass_through_transform(ZN2),
make_pass_through_transform(zN3), make_pass_through_transform(ZN3),
make_pass_through_transform(zN4)), make_pass_through_transform(ZN4)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -955,7 +955,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -955,7 +955,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize()); p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ZDataType*>(p_shared), static_cast<ushort*>(p_shared),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize()); z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx(); const auto wave_id = GetGemm0WaveIdx();
...@@ -974,7 +974,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -974,7 +974,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
m2, // MPerXdl m2, // MPerXdl
n2, // NGroupNum n2, // NGroupNum
n3, // NInputNum n3, // NInputNum
n4>, n4>, // registerNum
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim 7, // DstVectorDim
1, // DstScalarPerVector 1, // DstScalarPerVector
...@@ -1007,11 +1007,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1007,11 +1007,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
0, // nrepeat 0, // nrepeat
wave_id[I0], // MWaveId wave_id[I0], // MWaveId
wave_id[I1], // NWaveId wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), // MPerXdl wave_m_n_id[I1] / ZN4,
0, // group 0,
wave_m_n_id[I0], // NInputIndex wave_m_n_id[I0],
0, 0,
wave_m_n_id[I1] % 4)}; wave_m_n_id[I1] % ZN4)};
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort, ushort,
...@@ -1091,8 +1091,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1091,8 +1091,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false>; // SnakeCurved false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(zM0, zM1, zM2)), make_tuple(make_unmerge_transform(make_tuple(ZM0, ZM1, ZM2)),
make_unmerge_transform(make_tuple(zN0, zN1, zN2, zN3, zN4))), make_unmerge_transform(make_tuple(ZN0, ZN1, ZN2, ZN3, ZN4))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment