Commit 9ca59788 authored by Bartlomiej Wroblewski
Browse files

Review: Remove calls to dpp_gemm's MakeCDescriptor

parent b80b9e29
...@@ -160,7 +160,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2 ...@@ -160,7 +160,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
Number<MPerDpp>{}, Number<MPerDpp>{},
Number<NPerDpp>{})); Number<NPerDpp>{}));
return dpp_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2(c_block_desc_m0_n0_m1_n1_m2_n2); return c_block_desc_m0_n0_m1_n1_m2_n2;
} }
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2() __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2()
...@@ -173,8 +173,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2 ...@@ -173,8 +173,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
Number<NWaves>{}, Number<NWaves>{},
Number<MPerDpp>{}, Number<MPerDpp>{},
Number<NPerDpp>{})); Number<NPerDpp>{}));
return c_block_desc_g_m0_n0_m1_n1_m2_n2;
return dpp_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_N2(c_block_desc_g_m0_n0_m1_n1_m2_n2);
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
...@@ -191,7 +190,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2 ...@@ -191,7 +190,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return dpp_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); return c_grid_desc_m0_n0_m1_n1_m2_n2;
} }
template <typename CGridDesc_G_M_N> template <typename CGridDesc_G_M_N>
...@@ -210,7 +209,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2 ...@@ -210,7 +209,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return dpp_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_N2(c_grid_desc_g_m0_n0_m1_n1_m2_n2); return c_grid_desc_g_m0_n0_m1_n1_m2_n2;
} }
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
......
...@@ -199,72 +199,6 @@ struct DppGemm ...@@ -199,72 +199,6 @@ struct DppGemm
static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
} }
// Wraps a C descriptor in a transform_tensor_descriptor call that applies one
// pass-through transform per dimension (six in total), with each input
// dimension i mapped to output dimension i.
//
// The lengths used for the first four pass-through transforms (M0, N0, M1, N1)
// are read from the incoming descriptor. The last two are NOT read from the
// descriptor: they are the compile-time constants dpp_instr.m_per_wave and
// dpp_instr.n_per_wave.
//
// NOTE(review): this is the removed side of the diff; the commit replaces the
// callers' use of this function with returning their descriptor unchanged.
// Whether the callers' M2/N2 lengths always equal m_per_wave/n_per_wave is not
// visible in this hunk -- confirm before relying on the two being equivalent.
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
// Lengths of the four leading dimensions, taken from the input descriptor.
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
// One pass-through transform per dimension; the trailing two use the
// per-wave sizes from dpp_instr instead of lengths queried above.
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_pass_through_transform(Number<dpp_instr.m_per_wave>{}),
make_pass_through_transform(Number<dpp_instr.n_per_wave>{})),
// First index tuple: dimensions 0..5, one per transform.
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
// Second index tuple: identical to the first, so dimension order is kept.
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}));
}
// Variant of the C-descriptor wrapper with an extra leading G dimension: calls
// transform_tensor_descriptor with one pass-through transform per dimension
// (seven in total), with each input dimension i mapped to output dimension i.
//
// The lengths used for the first five pass-through transforms (G, M0, N0, M1,
// N1) are read from the incoming descriptor. The last two are the compile-time
// constants dpp_instr.m_per_wave and dpp_instr.n_per_wave rather than lengths
// queried from the descriptor.
//
// NOTE(review): this is the removed side of the diff; the commit replaces the
// callers' use of this function with returning their descriptor unchanged.
// Whether the callers' M2/N2 lengths always equal m_per_wave/n_per_wave is not
// visible in this hunk -- confirm before relying on the two being equivalent.
template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_G_M0_N0_M1_N1_M2_N2(const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
{
// Lengths of the five leading dimensions, taken from the input descriptor.
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
return transform_tensor_descriptor(
c_desc_g_m0_n0_m1_n1_m2_n2,
// One pass-through transform per dimension; the trailing two use the
// per-wave sizes from dpp_instr instead of lengths queried above.
make_tuple(make_pass_through_transform(G),
make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_pass_through_transform(Number<dpp_instr.m_per_wave>{}),
make_pass_through_transform(Number<dpp_instr.n_per_wave>{})),
// First index tuple: dimensions 0..6, one per transform.
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{}),
// Second index tuple: identical to the first, so dimension order is kept.
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{}));
}
__device__ static constexpr index_t GetRegSizePerDpp() __device__ static constexpr index_t GetRegSizePerDpp()
{ {
return MPerDpp * NPerDpp / dpp_instr.wave_size; return MPerDpp * NPerDpp / dpp_instr.wave_size;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment