Commit bed6f33c authored by Po-Yen, Chen
Browse files

Move argument field computing logic into device op side

parent 2a43fc3b
...@@ -209,7 +209,20 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -209,7 +209,20 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideB, index_t StrideB,
index_t StrideC) index_t StrideC)
{ {
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K)};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -236,7 +249,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -236,7 +249,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC); StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K));
} }
// polymorphic // polymorphic
......
...@@ -91,10 +91,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -91,10 +91,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto AK0_ = Number<KPerBlock / AK1Value>{}; static constexpr auto AK0_c = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0_ = Number<KPerBlock / BK1Value>{}; static constexpr auto BK0_c = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1_ = Number<AK1Value>{}; static constexpr auto AK1_c = Number<AK1Value>{};
static constexpr auto BK1_ = Number<BK1Value>{}; static constexpr auto BK1_c = Number<BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
...@@ -398,7 +398,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -398,7 +398,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
index_t K_, index_t K_,
index_t StrideA_, index_t StrideA_,
index_t StrideB_, index_t StrideB_,
index_t StrideC_) index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t AK0_,
index_t BK0_)
: p_a_grid{p_a_grid_}, : p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_}, p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}, p_c_grid{p_c_grid_},
...@@ -408,17 +413,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -408,17 +413,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
StrideA{StrideA_}, StrideA{StrideA_},
StrideB{StrideB_}, StrideB{StrideB_},
StrideC{StrideC_}, StrideC{StrideC_},
MPadded{CalculateMPadded(M_)}, MPadded{MPadded_},
NPadded{CalculateNPadded(N_)}, NPadded{NPadded_},
KPadded{CalculateKPadded(K_)}, KPadded{KPadded_},
AK0{CalculateAK0(K_)}, AK0{AK0_},
BK0{CalculateBK0(K_)}, BK0{BK0_},
a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1( a_grid_desc_ak0_m_ak1{
M_, CalculateMPadded(M_), K_, CalculateKPadded(K_), StrideA_, CalculateAK0(K_))}, MakeAGridDescriptor_AK0_M_AK1(M_, MPadded_, K_, KPadded_, StrideA_, AK0_)},
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1( b_grid_desc_bk0_n_bk1{
K_, CalculateKPadded(K_), N_, CalculateNPadded(N_), StrideB_, CalculateBK0(K_))}, MakeBGridDescriptor_BK0_N_BK1(K_, KPadded_, N_, NPadded_, StrideB_, BK0_)},
c_grid_desc_m_n{MakeCGridDescriptor_M_N( c_grid_desc_m_n{MakeCGridDescriptor_M_N(M_, MPadded_, N_, NPadded_, StrideC_)}
M_, CalculateMPadded(M_), N_, CalculateNPadded(N_), StrideC_)}
{ {
} }
...@@ -470,16 +474,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -470,16 +474,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(AK0_, Number<MPerBlock>{}, AK1_), make_tuple(AK0_c, Number<MPerBlock>{}, AK1_c),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1_, AK1_, I1)); make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1_c, AK1_c, I1));
} }
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{ {
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(BK0_, Number<NPerBlock>{}, BK1_), make_tuple(BK0_c, Number<NPerBlock>{}, BK1_c),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1_, BK1_, I1)); make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1_c, BK1_c, I1));
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -505,7 +509,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -505,7 +509,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(AK1_, BK1_); constexpr auto max_lds_align = math::lcm(AK1_c, BK1_c);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
...@@ -728,7 +732,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -728,7 +732,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(AK1_, BK1_); constexpr auto max_lds_align = math::lcm(AK1_c, BK1_c);
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
...@@ -742,7 +746,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -742,7 +746,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
AElementwiseOperation, AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<AK0_, MPerBlock, AK1_>, Sequence<AK0_c, MPerBlock, AK1_c>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
...@@ -773,7 +777,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -773,7 +777,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BElementwiseOperation, BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<BK0_, NPerBlock, BK1_>, Sequence<BK0_c, NPerBlock, BK1_c>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
...@@ -806,7 +810,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -806,7 +810,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// register // register
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(math::lcm(AK1_, BK1_), math::max(math::lcm(AK1_c, BK1_c),
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
...@@ -835,8 +839,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -835,8 +839,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1_, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1_c, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1_, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1_c, 0, 0);
// gridwise GEMM pipeline // gridwise GEMM pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>); static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment