"examples/community/pipeline_stable_diffusion_pag.py" did not exist on "0e0db625d0b7da2e6c27336845514b4460bd0000"
Commit 1814c9c1 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Reduce more temlate params of kernel

parent ea025f07
...@@ -312,12 +312,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -312,12 +312,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, Argument, true>;
kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
Argument,
true>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -331,12 +326,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -331,12 +326,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
} }
else else
{ {
const auto kernel = const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, Argument, false>;
kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
Argument,
false>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
......
...@@ -16,18 +16,14 @@ ...@@ -16,18 +16,14 @@
namespace ck { namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm, typename Argument, bool HasMainKBlockLoop>
typename FloatAB,
typename FloatC,
typename Argument,
bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const typename GridwiseGemm::FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, typename GridwiseGemm::FloatC* __restrict__ p_c_grid,
const Argument karg) const Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
...@@ -44,9 +40,9 @@ __global__ void ...@@ -44,9 +40,9 @@ __global__ void
} }
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatAB_,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC_,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
...@@ -98,6 +94,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -98,6 +94,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
using FloatAB = FloatAB_;
using FloatC = FloatC_;
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N) __host__ static auto CalculateGridSize(index_t M, index_t N)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment