Commit bed6f33c authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Move argument field computing logic into device op side

parent 2a43fc3b
......@@ -209,7 +209,20 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideB,
index_t StrideC)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K)};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -236,7 +249,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
K,
StrideA,
StrideB,
StrideC);
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K));
}
// polymorphic
......
......@@ -91,10 +91,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0_ = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0_ = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1_ = Number<AK1Value>{};
static constexpr auto BK1_ = Number<BK1Value>{};
static constexpr auto AK0_c = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0_c = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1_c = Number<AK1Value>{};
static constexpr auto BK1_c = Number<BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -398,7 +398,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t AK0_,
index_t BK0_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
......@@ -408,17 +413,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KPadded{CalculateKPadded(K_)},
AK0{CalculateAK0(K_)},
BK0{CalculateBK0(K_)},
a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(
M_, CalculateMPadded(M_), K_, CalculateKPadded(K_), StrideA_, CalculateAK0(K_))},
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(
K_, CalculateKPadded(K_), N_, CalculateNPadded(N_), StrideB_, CalculateBK0(K_))},
c_grid_desc_m_n{MakeCGridDescriptor_M_N(
M_, CalculateMPadded(M_), N_, CalculateNPadded(N_), StrideC_)}
MPadded{MPadded_},
NPadded{NPadded_},
KPadded{KPadded_},
AK0{AK0_},
BK0{BK0_},
a_grid_desc_ak0_m_ak1{
MakeAGridDescriptor_AK0_M_AK1(M_, MPadded_, K_, KPadded_, StrideA_, AK0_)},
b_grid_desc_bk0_n_bk1{
MakeBGridDescriptor_BK0_N_BK1(K_, KPadded_, N_, NPadded_, StrideB_, BK0_)},
c_grid_desc_m_n{MakeCGridDescriptor_M_N(M_, MPadded_, N_, NPadded_, StrideC_)}
{
}
......@@ -470,16 +474,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0_, Number<MPerBlock>{}, AK1_),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1_, AK1_, I1));
make_tuple(AK0_c, Number<MPerBlock>{}, AK1_c),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1_c, AK1_c, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0_, Number<NPerBlock>{}, BK1_),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1_, BK1_, I1));
make_tuple(BK0_c, Number<NPerBlock>{}, BK1_c),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1_c, BK1_c, I1));
}
__host__ __device__ static constexpr auto
......@@ -505,7 +509,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1_, BK1_);
constexpr auto max_lds_align = math::lcm(AK1_c, BK1_c);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
......@@ -728,7 +732,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1_, BK1_);
constexpr auto max_lds_align = math::lcm(AK1_c, BK1_c);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
......@@ -742,7 +746,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0_, MPerBlock, AK1_>,
Sequence<AK0_c, MPerBlock, AK1_c>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
......@@ -773,7 +777,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0_, NPerBlock, BK1_>,
Sequence<BK0_c, NPerBlock, BK1_c>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
......@@ -806,7 +810,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1_, BK1_),
math::max(math::lcm(AK1_c, BK1_c),
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
......@@ -835,8 +839,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1_, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1_, 0, 0);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1_c, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1_c, 0, 0);
// gridwise GEMM pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment