Commit 5765ba51 authored by coderfeli's avatar coderfeli
Browse files

Auto-calculate the previously hard-coded params

parent 3f9dbcac
......@@ -124,21 +124,24 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto BlockSizeNumber = Number<BlockSize>{};
static constexpr index_t NLane = 32;
static constexpr index_t NWave = 4;
static constexpr index_t KLane = 2;
static constexpr index_t KRepeat = 8;
static_assert(NLane * NWave * KLane == BlockSize);
static constexpr index_t NumDTensor = DsDataType::Size();
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NLane * NWave * KLane == BlockSize);
static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week");
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
......@@ -152,10 +155,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
// Type of the value returned by MakeDsGridPointer().
using DsGridPointer = decltype(MakeDsGridPointer());

// KPack is intentionally not defined here: it is declared earlier in the struct,
// next to mfma_selector, and a second definition would be a redefinition.

// Thread-block helper specialised for this kernel's BlockSize.
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
......@@ -321,11 +320,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
// Builds the naive strided tensor descriptor for the pre-shuffled B matrix.
//
// Lengths : (N0, K0, BlockSize * KPack)
// Strides : (K0 * BlockSize * KPack, BlockSize * KPack, 1)
//
// i.e. a packed 3-D layout whose innermost dimension is one BlockSize * KPack chunk.
//
// N0, K0: runtime extents of the two outer dimensions (how callers derive them from
//         N and K is not visible in this function).
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
    // Element count of the innermost, contiguous dimension.
    constexpr index_t NkSwizzle = BlockSize * KPack;

    // NOTE(review): despite its name this is stored as a plain index_t (the Number<>
    // is converted on initialisation), so the descriptor receives an integer rather
    // than a Number<> type here. Confirm whether `constexpr auto` was intended.
    constexpr index_t NkSwizzleNumber = Number<NkSwizzle>{};

    return make_naive_tensor_descriptor(
        make_tuple(N0, K0, NkSwizzleNumber),
        make_tuple(K0 * NkSwizzle, NkSwizzleNumber, I1));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment