Commit cf360b72 authored by ltqin's avatar ltqin
Browse files

create b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 from parameter

parent 8d4b51ca
......@@ -84,7 +84,7 @@ int main(int argc, char* argv[])
// GEMM shape
#if NORMAL_CONFIG
ck::index_t M = 3840;
ck::index_t M = 256;
ck::index_t N = 4096;
ck::index_t K = 4096;
......
......@@ -265,6 +265,9 @@ struct DeviceGemmXdlSkipLds
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_ =
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1_);
}
}
......@@ -275,6 +278,8 @@ struct DeviceGemmXdlSkipLds
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
......@@ -331,6 +336,7 @@ struct DeviceGemmXdlSkipLds
CDataType,
remove_reference_t<DeviceGemmXdlSkipLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
......@@ -348,6 +354,7 @@ struct DeviceGemmXdlSkipLds
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
......@@ -362,6 +369,7 @@ struct DeviceGemmXdlSkipLds
CDataType,
remove_reference_t<DeviceGemmXdlSkipLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
......@@ -379,6 +387,7 @@ struct DeviceGemmXdlSkipLds
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
......
......@@ -18,6 +18,7 @@ template <typename GridwiseGemm,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
......@@ -34,6 +35,7 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
......@@ -49,6 +51,7 @@ __global__ void
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
......@@ -60,6 +63,7 @@ __global__ void
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
ignore = a_element_op;
ignore = b_element_op;
......@@ -482,6 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
......@@ -491,6 +497,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
......@@ -575,9 +582,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>{};
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1);
const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment