Commit 822856e1 authored by Jing Zhang
Browse files

Rename KPerWave to KPack (GemmKPerWave -> GemmKPack, KPerWave -> KPack)

parent 5ac70ce0
......@@ -16,7 +16,7 @@ template <typename FloatAB,
index_t GemmNPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPerWave,
index_t GemmKPack,
typename... Wei,
typename... In,
typename... Out,
......@@ -110,7 +110,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPerWave>{};
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
......
......@@ -23,7 +23,7 @@ template <index_t BlockSize,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
index_t KPack,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M,
......@@ -100,7 +100,7 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
KPerBlock,
MPerWave,
NPerWave,
KPerWave,
KPack,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K_M,
......
......@@ -88,7 +88,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerWave = 1;
constexpr index_t GemmKPack = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
......@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerWave = 4;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
......@@ -138,7 +138,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
GemmNPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPerWave>(
GemmKPack>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
......@@ -164,7 +164,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPerWave,
GemmKPack,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment