Commit f1814e23 authored by Adam Osewski

Fix updating kbatch size.

parent 24c7c49f
@@ -183,6 +183,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
     using Block2ETileMapKSplit =
         BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
+    // Block2CTileMap configuration parameter.
+    static constexpr index_t B2E_M01 = 8;
     using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
     using KernelArgument = typename GridwiseGemm::Argument;
@@ -205,20 +207,29 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         }
     };
+    static constexpr index_t DefaultKBatch = 4;
     // Argument
     struct Argument : public BaseArgument
     {
         Argument(std::vector<const void*>& p_As,
                  std::vector<const void*>& p_Bs,
                  std::vector<void*>& p_Es,
                  std::vector<GemmDesc>& gemm_descs)
+            : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
         {
-            grid_size_ = 0;
-            // TODO: use occupancy api to calculate appropriate batch size.
-            const index_t K_BATCH = 4;
-            // Block2CTileMap configuration parameter.
-            constexpr index_t B2C_M01 = 8;
         }
+        Argument(std::vector<const void*>& p_As,
+                 std::vector<const void*>& p_Bs,
+                 std::vector<void*>& p_Es,
+                 std::vector<GemmDesc>& gemm_descs,
+                 index_t kbatch)
+            : K_BATCH{kbatch}
+        {
+            grid_size_ = 0;
             group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
             if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
@@ -232,7 +243,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             skipped_group_count_ = 0;
-            for(std::size_t i = 0; i < gemm_descs.size(); i++)
+            for(std::size_t i = 0; i < gemm_descs.size(); ++i)
             {
                 const index_t M = gemm_descs[i].M_;
                 const index_t N = gemm_descs[i].N_;
@@ -250,14 +261,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
                 const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
-                const index_t k_padded = GridwiseGemm::CalculateKPadded(K);
+                const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
                 const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
                 const auto c_grid_desc_m_n =
                     GridwiseGemm::MakeCGridDescriptor_M_N(M, N, m_padded, n_padded, stride_c);
                 const auto local_b2c_tile_map =
-                    Block2ETileMapKSplit{c_grid_desc_m_n, B2C_M01, K_BATCH};
+                    Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
                 const index_t block_start = grid_size_;
@@ -289,7 +300,50 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             }
         }
+        /**
+         * @brief Recalculate group grid size for all gemms and update B2C maps.
+         *
+         * @param[in] kbatch The new splitK parameter value.
+         */
+        void UpdateKBatch(index_t kbatch)
+        {
+            K_BATCH = kbatch;
+            grid_size_ = 0;
+            for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
+            {
+                auto& karg = gemm_kernel_args_[i].karg_;
+                const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
+                const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
+                const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
+                    karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);
+                const auto local_b2c_tile_map =
+                    Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
+                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
+                const index_t block_start = grid_size_;
+                const index_t block_end = grid_size_ + grid_size_grp;
+                grid_size_ += grid_size_grp;
+                // block-to-e-tile map
+                auto grouped_block_2_ctile_map =
+                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
+                karg.KPadded = k_padded;
+                karg.K0 = k0;
+                gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
+                gemm_kernel_args_[i].block_start_ = block_start;
+                gemm_kernel_args_[i].block_end_ = block_end;
+            }
+        }
         // private:
+        index_t K_BATCH;
         index_t group_count_;
         index_t skipped_group_count_;
@@ -454,11 +508,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         {
             const auto& a = arg.gemm_kernel_args_[i].karg_;
             bool group_arg_valid = GridwiseGemm::CheckValidity(a);
+#if DEBUG_LOG
             if(not group_arg_valid)
             {
                 std::cout << "[" << __func__ << "] group id: " << i << " is not supported!\n";
                 a.Print();
             }
+#endif // DEBUG_LOG
             supported &= group_arg_valid;
         }
         return supported;
@@ -542,18 +598,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                sizeof(GemmTransKernelArg);
     }
-    static void SetKBatchSize(const Argument& karg, index_t kbatch)
-    {
-        for(auto& arg : karg.gemm_kernel_args_)
-        {
-            arg.k_batch = kbatch;
-        }
-    }
+    static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
     // polymorphic
-    void SetKBatchSize(const BaseArgument* p_arg, index_t kbatch)
+    void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
     {
-        return SetKBatchSize(*dynamic_cast<const Argument*>(p_arg), kbatch);
+        return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
     }
 };
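
The net effect of the diff is that the split-K factor now lives on the `Argument` (as `K_BATCH`, defaulting to `DefaultKBatch = 4`), and the polymorphic `SetKBatchSize` override calls `Argument::UpdateKBatch`, which re-derives `KPadded`, `K0`, the per-group block ranges, and the block-to-C-tile maps instead of only patching a field on each kernel argument. The following stand-alone sketch is a toy model of that pattern, not Composable Kernel code: `DeviceOp`, `block_ranges_`, and the 128x128 tile arithmetic are invented for illustration and stand in for `GridwiseGemm::CalculateKPadded`/`CalculateK0` and `Block2ETileMapKSplit` used in the diff.

```cpp
// Toy model (not CK code) of the pattern introduced by this commit:
// kbatch is stored on the Argument, the short constructor delegates to the
// kbatch-aware one, and the polymorphic SetKBatchSize mutates the argument so
// the per-group grid ranges are recomputed rather than only a field patched.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using index_t = std::int32_t;

struct BaseArgument
{
    virtual ~BaseArgument() = default;
};

struct GemmDesc
{
    index_t M_, N_, K_;
};

struct Argument : public BaseArgument
{
    static constexpr index_t DefaultKBatch = 4;

    explicit Argument(const std::vector<GemmDesc>& gemm_descs)
        : Argument(gemm_descs, DefaultKBatch)
    {
    }

    Argument(const std::vector<GemmDesc>& gemm_descs, index_t kbatch)
        : K_BATCH{kbatch}, descs_{gemm_descs}
    {
        UpdateKBatch(kbatch);
    }

    // Mirrors Argument::UpdateKBatch in the diff: recompute each group's block
    // range so the total grid size reflects the new split-K factor.
    void UpdateKBatch(index_t kbatch)
    {
        K_BATCH    = kbatch;
        grid_size_ = 0;
        block_ranges_.clear();

        for(const auto& d : descs_)
        {
            // Simplified tile map: one workgroup per 128x128 output tile, times K_BATCH.
            const index_t tiles         = ((d.M_ + 127) / 128) * ((d.N_ + 127) / 128);
            const index_t grid_size_grp = tiles * K_BATCH;

            block_ranges_.push_back({grid_size_, grid_size_ + grid_size_grp});
            grid_size_ += grid_size_grp;
        }
    }

    index_t K_BATCH;
    index_t grid_size_ = 0;
    std::vector<GemmDesc> descs_;
    std::vector<std::pair<index_t, index_t>> block_ranges_;
};

struct DeviceOp
{
    static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }

    // Polymorphic entry point, as in the diff: the argument must be mutable,
    // hence the non-const pointer.
    void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const
    {
        SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
    }
};

int main()
{
    std::vector<GemmDesc> gemm_descs{{256, 256, 4096}, {512, 128, 2048}};

    Argument arg(gemm_descs); // uses DefaultKBatch = 4
    std::cout << "grid size with kbatch=4: " << arg.grid_size_ << '\n';

    DeviceOp{}.SetKBatchSize(&arg, 8); // re-derives grid size and block ranges
    std::cout << "grid size with kbatch=8: " << arg.grid_size_ << '\n';
}
```

The point the model preserves is that changing kbatch must resize the launch grid, which is why the new polymorphic override takes a non-const `BaseArgument*` where the removed one took a `const` pointer.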