Commit a3b31a92 authored by ltqin's avatar ltqin
Browse files

driver variale name

parent 149296c0
......@@ -158,7 +158,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
}
template <typename CGMNGridDesc>
template <typename CGMNGridDesc>
__host__ __device__ static constexpr auto
MakeCGM0N0M1N1M2M3M4N2GridDescriptor(const CGMNGridDesc& c_g_m_n_grid_desc)
{
......@@ -168,8 +168,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCGM0N0M1N1M2M3M4N2Descriptor(c_g_m0_n0_m1_n1_m2_n2_grid_desc);
}
......
......@@ -11,9 +11,9 @@ template <ck::index_t BlockSize,
typename FloatAcc,
typename FloatC,
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
typename AGK0MK1GridDesc,
typename BGK0NK1GridDesc,
typename CGMNGridDesc,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
......@@ -50,9 +50,9 @@ template <ck::index_t BlockSize,
__host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AK0MK1GridDesc& a_g_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_g_k0_n_k1_grid_desc,
const CMNGridDesc& c_g_m_n_grid_desc,
const AGK0MK1GridDesc& a_g_k0_m_k1_grid_desc,
const BGK0NK1GridDesc& b_g_k0_n_k1_grid_desc,
const CGMNGridDesc& c_g_m_n_grid_desc,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
......@@ -69,14 +69,14 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
constexpr auto I3 = Number<3>{};
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1<BlockSize,
GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
AGK0MK1GridDesc,
BGK0NK1GridDesc,
CGMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
......@@ -134,26 +134,26 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const auto c_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_g_m_n_grid_desc);
const auto c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCGM0N0M1N1M2M3M4N2GridDescriptor(c_g_m_n_grid_desc);
/* using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
using CGM0N0M1N1M2M3M4N2GridDesc = decltype(c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_g_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_g_m_n_grid_desc);
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
const auto kernel = kernel_gemm_xdlops_v3r1<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<AGK0MK1GridDesc>,
remove_reference_t<BGK0NK1GridDesc>,
remove_reference_t<CGM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
/* #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment