Commit a3b31a92 authored by ltqin
Browse files

driver variable name

parent 149296c0
...@@ -158,7 +158,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -158,7 +158,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc); return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
} }
template <typename CGMNGridDesc> template <typename CGMNGridDesc>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGM0N0M1N1M2M3M4N2GridDescriptor(const CGMNGridDesc& c_g_m_n_grid_desc) MakeCGM0N0M1N1M2M3M4N2GridDescriptor(const CGMNGridDesc& c_g_m_n_grid_desc)
{ {
...@@ -168,8 +168,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -168,8 +168,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(make_pass_through_transform(G), make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)), make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCGM0N0M1N1M2M3M4N2Descriptor(c_g_m0_n0_m1_n1_m2_n2_grid_desc); return xdlops_gemm.MakeCGM0N0M1N1M2M3M4N2Descriptor(c_g_m0_n0_m1_n1_m2_n2_grid_desc);
} }
......
...@@ -11,9 +11,9 @@ template <ck::index_t BlockSize, ...@@ -11,9 +11,9 @@ template <ck::index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc, typename AGK0MK1GridDesc,
typename BK0NK1GridDesc, typename BGK0NK1GridDesc,
typename CMNGridDesc, typename CGMNGridDesc,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
ck::index_t KPerBlock, ck::index_t KPerBlock,
...@@ -50,9 +50,9 @@ template <ck::index_t BlockSize, ...@@ -50,9 +50,9 @@ template <ck::index_t BlockSize,
__host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid, __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AK0MK1GridDesc& a_g_k0_m_k1_grid_desc, const AGK0MK1GridDesc& a_g_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_g_k0_n_k1_grid_desc, const BGK0NK1GridDesc& b_g_k0_n_k1_grid_desc,
const CMNGridDesc& c_g_m_n_grid_desc, const CGMNGridDesc& c_g_m_n_grid_desc,
AGridStepHacks, AGridStepHacks,
BGridStepHacks, BGridStepHacks,
CGridStepHacks, CGridStepHacks,
...@@ -69,14 +69,14 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid, ...@@ -69,14 +69,14 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
using GridwiseGemm = using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1<BlockSize, GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
AK0MK1GridDesc, AGK0MK1GridDesc,
BK0NK1GridDesc, BGK0NK1GridDesc,
CMNGridDesc, CGMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
...@@ -134,26 +134,26 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid, ...@@ -134,26 +134,26 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
} }
const auto c_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = const auto c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_g_m_n_grid_desc); GridwiseGemm::MakeCGM0N0M1N1M2M3M4N2GridDescriptor(c_g_m_n_grid_desc);
/* using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); using CGM0N0M1N1M2M3M4N2GridDesc = decltype(c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_g_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); const index_t grid_size = GridwiseGemm::CalculateGridSize(c_g_m_n_grid_desc);
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, const auto kernel = kernel_gemm_xdlops_v3r1<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0MK1GridDesc>, remove_reference_t<AGK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>, remove_reference_t<BGK0NK1GridDesc>,
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>, remove_reference_t<CGM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>; remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE /* #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel, float ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment