Commit 9fdc3fc8 authored by Po-Yen, Chen
Browse files

Remove no-longer used data members

parent 5581dc00
......@@ -385,22 +385,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
N,
GridwiseGemm::CalculateNPadded(N),
StrideC)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
kraw_{K}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
}
// private:
......@@ -413,9 +402,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
......@@ -446,10 +432,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
#endif
if(!GridwiseGemm::CheckValidity(karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_m_n_,
karg.block_2_ctile_map_))
if(!GridwiseGemm::CheckValidity(
karg.a_grid_desc_ak0_m_ak1_, karg.b_grid_desc_bk0_n_bk1_, karg.c_grid_desc_m_n_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
......@@ -463,8 +447,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
GridwiseGemm,
const auto kernel =
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
......@@ -472,12 +456,10 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
DeviceOp::CGridDesc_M_N,
true>;
ave_time =
launch_and_time_kernel(stream_config,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
......@@ -490,13 +472,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
karg.c_element_op_,
karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
karg.block_2_ctile_map_);
karg.c_grid_desc_m_n_);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
GridwiseGemm,
const auto kernel =
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
......@@ -504,11 +485,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
DeviceOp::CGridDesc_M_N,
false>;
ave_time =
launch_and_time_kernel(stream_config,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
......@@ -521,8 +500,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
karg.c_element_op_,
karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
karg.block_2_ctile_map_);
karg.c_grid_desc_m_n_);
}
return ave_time;
......@@ -558,10 +536,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
}
// polymorphic
......
......@@ -25,8 +25,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap,
typename CGridDesc_M_N,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -40,9 +39,7 @@ __global__ void
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
const CGridDesc_M_N c_grid_desc_m_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -56,8 +53,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
c_grid_desc_m_n);
#else
ignore = p_a_grid;
ignore = p_b_grid;
......@@ -67,8 +63,7 @@ __global__ void
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
ignore = c_grid_desc_m_n;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -144,7 +139,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// Builds the launch grid for an M x N problem as a 3-element tuple: the first element is
// whatever the block-to-C-tile map's static CalculateGridSize(M, N) returns, and the other
// two are fixed to 1.
// NOTE(review): this text comes from a diff view without +/- markers. The two `return`
// statements below appear to be the pre-change (DefaultBlock2CTileMap) and post-change
// (Block2CTileMap) versions of the same line; as written only the first is reachable.
// Confirm against the actual header which one is live.
__host__ __device__ static auto CalculateGridSize(index_t M, index_t N)
{
// reference the implementation of class 'BlockToCTileMap_M00_N0_M01Adapt'
return std::make_tuple(DefaultBlock2CTileMap::CalculateGridSize(M, N), 1, 1);
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
}
__host__ __device__ static auto CalculateMPadded(index_t M)
......@@ -269,12 +264,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
const CGridDesc_M_N& c_grid_desc_m_n)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
......@@ -298,11 +291,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
......@@ -335,7 +323,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
......@@ -344,10 +332,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap>
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
......@@ -357,10 +344,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
const CGridDesc_M_N c_grid_desc_m_n)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -369,6 +357,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_2_ctile_map = MakeBlock2CTileMap(c_grid_desc_m_n);
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment