Commit bac321a9 authored by Po-Yen, Chen
Browse files

Remove 'block_2_ctile_map' kernel parameter

parent a65d6459
......@@ -243,9 +243,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01)
index_t StrideC)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
......@@ -253,16 +251,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
kraw_{K}
{
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
// private:
......@@ -273,8 +268,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
index_t kraw_;
};
......@@ -319,15 +312,14 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M_N>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
const auto kernel =
kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M_N>,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -339,20 +331,18 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
arg.c_grid_desc_m_n_);
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M_N>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
const auto kernel =
kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M_N>,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -364,8 +354,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
arg.c_grid_desc_m_n_);
}
return ave_time;
......@@ -438,7 +427,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1};
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
// Factory helper: returns a default-constructed Invoker by value.
static auto MakeInvoker()
{
    return Invoker{};
}
......@@ -465,9 +454,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
K,
StrideA,
StrideB,
StrideC,
1,
1);
StrideC);
}
// polymorphic
......
......@@ -22,7 +22,6 @@ template <typename GridwiseGemm,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -33,8 +32,7 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M_N c_grid_desc_m_n,
const Block2CTileMap block_2_ctile_map)
const CGridDesc_M_N c_grid_desc_m_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
......@@ -46,8 +44,7 @@ __global__ void
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
block_2_ctile_map);
c_grid_desc_m_n);
#else
ignore = p_a_grid;
ignore = p_b_grid;
......@@ -55,7 +52,6 @@ __global__ void
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m_n;
ignore = block_2_ctile_map;
#endif // end of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
}
......@@ -293,8 +289,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
......@@ -302,17 +298,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
......@@ -330,6 +325,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto block_2_ctile_map = MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment