Commit 1814c9c1 authored by Po-Yen, Chen
Browse files

Reduce more template params of kernel

parent ea025f07
......@@ -312,12 +312,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel =
kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
Argument,
true>;
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, Argument, true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -331,12 +326,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
}
else
{
const auto kernel =
kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
Argument,
false>;
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, Argument, false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......
......@@ -16,18 +16,14 @@
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename Argument,
bool HasMainKBlockLoop>
template <typename GridwiseGemm, typename Argument, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::FloatAB* __restrict__ p_a_grid,
const typename GridwiseGemm::FloatAB* __restrict__ p_b_grid,
typename GridwiseGemm::FloatC* __restrict__ p_c_grid,
const Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
......@@ -44,9 +40,9 @@ __global__ void
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAB_,
typename FloatAcc,
typename FloatC,
typename FloatC_,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
......@@ -98,6 +94,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
using FloatAB = FloatAB_;
using FloatC = FloatC_;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment