"...composable_kernel_rocm.git" did not exist on "f91579aab6e224c23aceaeaa0a29d9dde83f09ed"
Commit 4774d863 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 5dd45128
@@ -98,9 +98,26 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     using CM0M1N0N1GridDesc = decltype(c_m0_m1_n0_n1_grid_desc);

-#if 0
-    const auto c_m0_m10_m
-#endif
+    constexpr auto M1 = Number<MPerBlock>{};
+    constexpr auto N1 = Number<NPerBlock>{};
+
+    const auto M0 = M / M1;
+    const auto N0 = N / N1;
+
+    constexpr auto M11 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
+    constexpr auto N11 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
+
+    constexpr auto M10 = M1 / M11;
+    constexpr auto N10 = N1 / N11;
+
+    const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
+        c_m_n_grid_desc,
+        make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
+                   make_unmerge_transform(make_tuple(N0, N10, N11))),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
+
+    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

     // out_gemm_block_cluster_desc
     const auto c_block_cluster_desc =
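Note: the added descriptor splits the output M dimension twice, M = M0 * M1 with M1 = MPerBlock, and M1 = M10 * M11 with M11 the per-thread-cluster tile; N is split the same way. The snippet below is a standalone sketch (not part of this commit) of the index arithmetic those two unmerge transforms encode; the concrete sizes are made-up example values.

```cpp
// Standalone sketch, not CK code: how a global row index m decomposes into
// (m0, m10, m11) under M = M0 * MPerBlock and MPerBlock = M10 * M11.
// All sizes below are illustrative assumptions.
#include <cassert>
#include <cstdio>

int main()
{
    const int MPerBlock = 128;       // M1: rows owned by one workgroup
    const int M11       = 4 * 8 * 2; // stands in for M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10
    const int M10       = MPerBlock / M11;
    const int M         = 1024;
    const int M0        = M / MPerBlock;

    const int m   = 777;                       // some global row index
    const int m0  = m / MPerBlock;             // which block-row
    const int m10 = (m % MPerBlock) / M11;     // which thread-cluster tile inside the block
    const int m11 = m % M11;                   // offset inside that tile

    assert(m0 < M0 && m10 < M10 && m11 < M11);
    // Recombining the split indices must give back the original index.
    assert(m == (m0 * M10 + m10) * M11 + m11);
    std::printf("m=%d -> (m0, m10, m11) = (%d, %d, %d)\n", m, m0, m10, m11);
    return 0;
}
```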
@@ -119,6 +136,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
         BKNGridDesc,
         CM0M1N0N1GridDesc,
         CBlockClusterDesc,
+        CM0M10M11N0N10N11GridDesc,
         MPerBlock,
         NPerBlock,
         KPerBlock,
@@ -160,7 +178,6 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

-#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
@@ -173,6 +190,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                remove_reference_t<BKNGridDesc>,
                remove_reference_t<CM0M1N0N1GridDesc>,
                remove_reference_t<CBlockClusterDesc>,
+                remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                true,
                true>;
@@ -188,7 +206,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            c_m0_m1_n0_n1_grid_desc,
-            c_block_cluster_desc);
+            c_block_cluster_desc,
+            c_m0_m10_m11_n0_n10_n11_grid_desc);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
@@ -200,6 +219,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                remove_reference_t<BKNGridDesc>,
                remove_reference_t<CM0M1N0N1GridDesc>,
                remove_reference_t<CBlockClusterDesc>,
+                remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                true,
                false>;
@@ -215,7 +235,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            c_m0_m1_n0_n1_grid_desc,
-            c_block_cluster_desc);
+            c_block_cluster_desc,
+            c_m0_m10_m11_n0_n10_n11_grid_desc);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
@@ -227,6 +248,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                remove_reference_t<BKNGridDesc>,
                remove_reference_t<CM0M1N0N1GridDesc>,
                remove_reference_t<CBlockClusterDesc>,
+                remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                false,
                true>;
@@ -242,7 +264,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            c_m0_m1_n0_n1_grid_desc,
-            c_block_cluster_desc);
+            c_block_cluster_desc,
+            c_m0_m10_m11_n0_n10_n11_grid_desc);
    }
    else
    {
@@ -254,6 +277,7 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                remove_reference_t<BKNGridDesc>,
                remove_reference_t<CM0M1N0N1GridDesc>,
                remove_reference_t<CBlockClusterDesc>,
+                remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                false,
                false>;
@@ -269,138 +293,11 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            c_m0_m1_n0_n1_grid_desc,
-            c_block_cluster_desc);
+            c_block_cluster_desc,
+            c_m0_m10_m11_n0_n10_n11_grid_desc);
    }

    return ave_time;
-#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
-    DeviceMem a_k_m_grid_desc_device_buf(sizeof(AKMGridDesc));
-    DeviceMem b_k_n_grid_desc_device_buf(sizeof(BKNGridDesc));
-    DeviceMem c_m0_m1_n0_n1_grid_desc_device_buf(sizeof(CM0M1N0N1GridDesc));
-    DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
-
-    a_k_m_grid_desc_device_buf.ToDevice(&a_k_m_grid_desc);
-    b_k_n_grid_desc_device_buf.ToDevice(&b_k_n_grid_desc);
-    c_m0_m1_n0_n1_grid_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_grid_desc);
-    c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
-
-    float ave_time = 0;
-
-    if(has_main_k_block_loop && has_double_tail_k_block_loop)
-    {
-        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
-                                                     FloatAB,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M1N0N1GridDesc>,
-                                                     remove_reference_t<CBlockClusterDesc>,
-                                                     true,
-                                                     true>;
-
-        ave_time = launch_and_time_kernel(
-            kernel,
-            nrepeat,
-            dim3(GridSize),
-            dim3(BlockSize),
-            0,
-            0,
-            p_a_grid,
-            p_b_grid,
-            p_c_grid,
-            (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
-    }
-    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
-    {
-        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
-                                                     FloatAB,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M1N0N1GridDesc>,
-                                                     remove_reference_t<CBlockClusterDesc>,
-                                                     true,
-                                                     false>;
-
-        ave_time = launch_and_time_kernel(
-            kernel,
-            nrepeat,
-            dim3(GridSize),
-            dim3(BlockSize),
-            0,
-            0,
-            p_a_grid,
-            p_b_grid,
-            p_c_grid,
-            (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
-    }
-    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
-    {
-        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
-                                                     FloatAB,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M1N0N1GridDesc>,
-                                                     remove_reference_t<CBlockClusterDesc>,
-                                                     false,
-                                                     true>;
-
-        ave_time = launch_and_time_kernel(
-            kernel,
-            nrepeat,
-            dim3(GridSize),
-            dim3(BlockSize),
-            0,
-            0,
-            p_a_grid,
-            p_b_grid,
-            p_c_grid,
-            (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
-    }
-    else
-    {
-        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
-                                                     FloatAB,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M1N0N1GridDesc>,
-                                                     remove_reference_t<CBlockClusterDesc>,
-                                                     false,
-                                                     false>;
-
-        ave_time = launch_and_time_kernel(
-            kernel,
-            nrepeat,
-            dim3(GridSize),
-            dim3(BlockSize),
-            0,
-            0,
-            p_a_grid,
-            p_b_grid,
-            p_c_grid,
-            (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
-            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
-    }
-
-    return ave_time;
-#endif
 }
 } // namespace ck
@@ -93,7 +93,8 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
    // output tensor
    const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
-        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
+        make_tuple(make_pass_through_transform(K),
+                   make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
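The pass-through/merge pair above views the packed (N, K, Ho*Wo) output tensor as a GemmM x GemmN matrix with GemmM = K and GemmN = N * Ho * Wo. Below is a standalone sketch of that index mapping (not part of this commit; the sizes are illustrative).

```cpp
// Standalone sketch, not CK code: mapping a GEMM coordinate back to the
// packed (N, K, Ho*Wo) output tensor. Sizes are illustrative assumptions.
#include <cstdio>

int main()
{
    const int N = 2, K = 8, HoWo = 6; // packed (N, K, Ho*Wo) lengths

    // offset of element (n, k, howo) in the packed layout
    auto offset_nkhw = [&](int n, int k, int howo) { return (n * K + k) * HoWo + howo; };

    std::printf("GemmM = %d, GemmN = %d\n", K, N * HoWo);

    // GEMM view: row = k (pass-through), col = merged (n, howo)
    const int gemm_m = 5;             // k
    const int gemm_n = 9;             // merged index in [0, N*HoWo)
    const int n      = gemm_n / HoWo; // unmerge back to (n, howo)
    const int howo   = gemm_n % HoWo;

    std::printf("C[%d][%d] lives at packed offset %d\n",
                gemm_m, gemm_n, offset_nkhw(n, gemm_m, howo));
    return 0;
}
```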
@@ -12,15 +12,15 @@
 namespace ck {

-#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AKMGridDesc,
          typename BKNGridDesc,
-          typename CM0M1N0N1GridDesc,
+          typename CM10M11N10N11GridDesc,
          typename CBlockClusterDesc,
+          typename CM0M10M11N0N10N11GridDesc,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
 __global__ void
@@ -32,69 +32,21 @@ __global__ void
                             FloatC* __restrict__ p_c_grid,
                             const AKMGridDesc a_k_m_grid_desc,
                             const BKNGridDesc b_k_n_grid_desc,
-                             const CM0M1N0N1GridDesc c_m0_m1_n0_n1_grid_desc,
-                             const CBlockClusterDesc c_block_cluster_desc)
+                             const CM10M11N10N11GridDesc c_m10_n10_m11_n11_grid_desc,
+                             const CBlockClusterDesc c_block_cluster_desc,
+                             const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc)
 {
    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      a_k_m_grid_desc,
                      b_k_n_grid_desc,
-                      c_m0_m1_n0_n1_grid_desc,
+                      c_m10_n10_m11_n11_grid_desc,
                      c_block_cluster_desc,
+                      c_m0_m10_m11_n0_n10_n11_grid_desc,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
-#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
-// pass tensor descriptor by __CONSTANT__ void pointer
-// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
-// non-modifiable parameter address space, so compiler can enable corresponding optimization
-template <typename GridwiseGemm,
-          typename FloatA,
-          typename FloatB,
-          typename FloatC,
-          typename AKMGridDesc,
-          typename BKNGridDesc,
-          typename CM0M1N0N1GridDesc,
-          typename CBlockClusterDesc,
-          bool HasMainKBlockLoop,
-          bool HasDoubleTailKBlockLoop>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
-    kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_grid,
-                             const FloatB* __restrict__ p_b_grid,
-                             FloatC* __restrict__ p_c_grid,
-                             const void __CONSTANT__* p_a_k_m_grid_desc,
-                             const void __CONSTANT__* p_b_k_n_grid_desc,
-                             const void __CONSTANT__* p_c_m0_m1_n0_n1_grid_desc,
-                             const void __CONSTANT__* p_c_block_cluster_desc)
-{
-    // first cast void __CONSTANT__ void* to void*
-    // second cast void* to Desc*
-    // the copy constructor of tensor descriptor doesn't take address_space(4)
-    const auto a_k_m_grid_desc =
-        *reinterpret_cast<const AKMGridDesc*>((const void*)p_a_k_m_grid_desc);
-    const auto b_k_n_grid_desc =
-        *reinterpret_cast<const BKNGridDesc*>((const void*)p_b_k_n_grid_desc);
-    const auto c_m0_m1_n0_n1_grid_desc =
-        *reinterpret_cast<const CM0M1N0N1GridDesc*>((const void*)p_c_m0_m1_n0_n1_grid_desc);
-    const auto c_block_cluster_desc =
-        *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
-
-    GridwiseGemm::Run(p_a_grid,
-                      p_b_grid,
-                      p_c_grid,
-                      a_k_m_grid_desc,
-                      b_k_n_grid_desc,
-                      c_m0_m1_n0_n1_grid_desc,
-                      c_block_cluster_desc,
-                      integral_constant<bool, HasMainKBlockLoop>{},
-                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
-}
-#endif

 template <index_t BlockSize,
          typename FloatAB,
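The deleted branch above passed each tensor descriptor to the kernel as a __CONSTANT__ void pointer: the host stages the descriptor object in device memory, and the kernel casts the pointer back and copies the descriptor out. The following is a minimal HIP sketch of that pattern for illustration only, not the CK implementation; plain hipMalloc/hipMemcpy stand in for DeviceMem, and the __CONSTANT__ address-space qualifier is omitted.

```cpp
// Minimal HIP sketch (assumption, not CK code) of passing a descriptor struct
// to a kernel through an opaque device-memory pointer.
#include <hip/hip_runtime.h>
#include <cstdio>

struct Desc // stand-in for a tensor descriptor; must be trivially copyable
{
    int lengths[2];
    int strides[2];
};

__global__ void kernel(const void* p_desc)
{
    // cast the opaque pointer back and copy the descriptor into registers
    const Desc desc = *reinterpret_cast<const Desc*>(p_desc);

    if(threadIdx.x == 0 && blockIdx.x == 0)
        printf("lengths = (%d, %d)\n", desc.lengths[0], desc.lengths[1]);
}

int main()
{
    Desc host_desc{{1024, 512}, {512, 1}};

    // stage the descriptor in device memory (error checking omitted)
    void* p_desc_dev = nullptr;
    hipMalloc(&p_desc_dev, sizeof(Desc));
    hipMemcpy(p_desc_dev, &host_desc, sizeof(Desc), hipMemcpyHostToDevice);

    kernel<<<dim3(1), dim3(64)>>>(p_desc_dev);
    hipDeviceSynchronize();

    hipFree(p_desc_dev);
    return 0;
}
```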
@@ -103,18 +55,19 @@ template <index_t BlockSize,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AKMGridDesc,
          typename BKNGridDesc,
-          typename CM0M1N0N1GridDesc,
+          typename CM10M11N10N11GridDesc,
          typename CBlockClusterDesc,
+          typename CM0M10M11N0N10N11GridDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
-          index_t M1N1ThreadClusterM10,
-          index_t M1N1ThreadClusterN10,
-          index_t M1N1ThreadClusterM11,
-          index_t M1N1ThreadClusterN11,
+          index_t M1N1ThreadClusterM100,
+          index_t M1N1ThreadClusterN100,
+          index_t M1N1ThreadClusterM101,
+          index_t M1N1ThreadClusterN101,
          typename ABlockTransferThreadSliceLengths_K_M,
          typename ABlockTransferThreadClusterLengths_K_M,
          typename ABlockTransferThreadClusterArrangeOrder,
@@ -174,7 +127,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
    }

    __host__ __device__ static constexpr auto
-    MakeAKM0M1BlockClusterizedGridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
+    MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
    {
        const auto K = a_k_m_grid_desc.GetLength(I0);
        const auto M = a_k_m_grid_desc.GetLength(I1);
@@ -192,7 +145,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
    }

    __host__ __device__ static constexpr auto
-    MakeBKN0N1BlockClusterizedGridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
+    MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
    {
        const auto K = b_k_n_grid_desc.GetLength(I0);
        const auto N = b_k_n_grid_desc.GetLength(I1);
@@ -209,8 +162,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        return b_k_n0_n1_block_clusterized_grid_desc;
    }

-    using AKM0M1GridDesc = decltype(MakeAKM0M1BlockClusterizedGridDescriptor(AKMGridDesc{}));
-    using BKN0N1GridDesc = decltype(MakeBKN0N1BlockClusterizedGridDescriptor(BKNGridDesc{}));
+#if 0
+    __host__ __device__ static constexpr auto
+    MakeCM0M10M11N0N10N11GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
+    {
+    }
+#endif
+
+    using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
+    using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
@@ -218,8 +178,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                               FloatC* __restrict__ p_c_grid,
                               const AKMGridDesc& a_k_m_grid_desc,
                               const BKNGridDesc& b_k_n_grid_desc,
-                               const CM0M1N0N1GridDesc& c_m0_m1_n0_n1_grid_desc,
+                               const CM10M11N10N11GridDesc& c_m10_n10_m11_n11_grid_desc,
                               const CBlockClusterDesc& c_block_cluster_desc,
+                               const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
                               FloatAB* __restrict__ p_shared_block,
                               integral_constant<bool, HasMainKBlockLoop>,
                               integral_constant<bool, HasDoubleTailKBlockLoop>)
@@ -229,12 +190,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
            p_b_grid, b_k_n_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_c_grid, c_m0_m1_n0_n1_grid_desc.GetElementSpaceSize());
+            p_c_grid, c_m10_n10_m11_n11_grid_desc.GetElementSpaceSize());

        const auto K = a_k_m_grid_desc.GetLength(I0);
        const auto M = a_k_m_grid_desc.GetLength(I1);
        const auto N = b_k_n_grid_desc.GetLength(I1);

+        constexpr auto M1 = Number<MPerBlock>{};
+        constexpr auto N1 = Number<NPerBlock>{};
+
+        const auto M0 = M / M1;
+        const auto N0 = N / N1;
+
        // divide block work by [M, N]
        const auto block_work_idx =
            c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -332,17 +299,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                                                      M1PerThread,
                                                      N1PerThread,
                                                      KPerThread,
-                                                      M1N1ThreadClusterM10,
-                                                      M1N1ThreadClusterN10,
-                                                      M1N1ThreadClusterM11,
-                                                      M1N1ThreadClusterN11,
+                                                      M1N1ThreadClusterM100,
+                                                      M1N1ThreadClusterN100,
+                                                      M1N1ThreadClusterM101,
+                                                      M1N1ThreadClusterN101,
                                                      M1PerThread,
                                                      N1PerThread>{};

-        constexpr auto c_m0_m1_n0_n1_thread_tensor_lengths =
+        constexpr auto c_m10_n10_m11_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

-        constexpr auto c_m0_m1_n0_n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
-            sequence_to_tuple_of_number(c_m0_m1_n0_n1_thread_tensor_lengths));
+        constexpr auto c_m10_n10_m11_n11_thread_desc =
+            make_dynamic_naive_tensor_descriptor_packed_v2(
+                sequence_to_tuple_of_number(c_m10_n10_m11_n11_thread_tensor_lengths));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
@@ -356,12 +324,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
-            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+            c_m10_n10_m11_n11_thread_desc.GetElementSpaceSize());

        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
-                                           decltype(c_m0_m1_n0_n1_thread_desc),
-                                           decltype(c_m0_m1_n0_n1_thread_tensor_lengths)>{}
-            .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
+                                           decltype(c_m10_n10_m11_n11_thread_desc),
+                                           decltype(c_m10_n10_m11_n11_thread_tensor_lengths)>{}
+            .Run(c_m10_n10_m11_n11_thread_desc,
+                 make_tuple(I0, I0, I0, I0),
+                 c_thread_buf,
+                 FloatAcc{0});

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
@@ -421,8 +392,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                    b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);

                // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(
-                    c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
+                blockwise_gemm.Run(c_m10_n10_m11_n11_thread_desc,
+                                   a_block_even_buf,
+                                   b_block_even_buf,
+                                   c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
@@ -446,7 +419,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(
-                    c_m0_m1_n0_n1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+                    c_m10_n10_m11_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
@@ -474,7 +447,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(
-                c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
+                c_m10_n10_m11_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
@@ -484,7 +457,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
-                c_m0_m1_n0_n1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+                c_m10_n10_m11_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if has 1 iteration left
        {
@@ -492,42 +465,95 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
-                c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
+                c_m10_n10_m11_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
        }

        // output: register to global memory
        {
-            constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM10 * M1N1ThreadClusterM11>{};
-            constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};
-
-            // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-            constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGridIteratorHacks{};
+#if 0
+            constexpr auto M11 = Number<M1PerThread * M1N1ThreadClusterM100 * M1N1ThreadClusterM101>{};
+            constexpr auto N11 = Number<N1PerThread * M1N1ThreadClusterN100 * M1N1ThreadClusterN101>{};
+
+            // hack to control index calculation when iterating over c_m10_n10_m11_n11_global tensor
+            constexpr auto c_m10_n10_m11_n11_global_tensor_iterator_hacks = CGridIteratorHacks{};

            const auto c_thread_data_idx_on_block =
                blockwise_gemm.CalculateCM0M1N0N1ThreadOriginIndex(get_thread_local_1d_id());

            ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
                                                      FloatC,
-                                                      decltype(c_m0_m1_n0_n1_thread_desc),
-                                                      decltype(c_m0_m1_n0_n1_grid_desc),
-                                                      decltype(c_m0_m1_n0_n1_thread_tensor_lengths),
+                                                      decltype(c_m10_n10_m11_n11_thread_desc),
+                                                      decltype(c_m10_n10_m11_n11_grid_desc),
+                                                      decltype(c_m10_n10_m11_n11_thread_tensor_lengths),
                                                      CThreadTransferSrcDstAccessOrder,
                                                      CThreadTransferSrcDstVectorDim,
                                                      CThreadTransferDstScalarPerVector,
                                                      CGlobalMemoryDataOperation,
                                                      1,
                                                      true>{
-                c_m0_m1_n0_n1_grid_desc,
-                make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
+                c_m10_n10_m11_n11_grid_desc,
+                make_multi_index(m_block_data_idx_on_global / M11 + c_thread_data_idx_on_block[I0],
                                 c_thread_data_idx_on_block[I1],
-                                 n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
+                                 n_block_data_idx_on_global / N11 + c_thread_data_idx_on_block[I2],
                                 c_thread_data_idx_on_block[I3])}
-                .Run(c_m0_m1_n0_n1_thread_desc,
+                .Run(c_m10_n10_m11_n11_thread_desc,
                     make_tuple(I0, I0, I0, I0),
                     c_thread_buf,
-                     c_m0_m1_n0_n1_grid_desc,
+                     c_m10_n10_m11_n11_grid_desc,
                     c_grid_buf,
-                     c_m0_m1_n0_n1_global_tensor_iterator_hacks);
+                     c_m10_n10_m11_n11_global_tensor_iterator_hacks);
+#else
+            constexpr index_t M11 = M1PerThread * M1N1ThreadClusterM100 * M1N1ThreadClusterM101;
+            constexpr index_t N11 = N1PerThread * M1N1ThreadClusterN100 * M1N1ThreadClusterN101;
+
+            constexpr index_t M10 = MPerBlock / M11;
+            constexpr index_t N10 = NPerBlock / N11;
+
+            constexpr index_t M111 = M1PerThread;
+            constexpr index_t N111 = N1PerThread;
+
+            constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
+                make_dynamic_naive_tensor_descriptor_packed_v2(
+                    make_tuple(Number<1>{},
+                               Number<c_m10_n10_m11_n11_thread_tensor_lengths[I0]>{},
+                               Number<c_m10_n10_m11_n11_thread_tensor_lengths[I1]>{},
+                               Number<1>{},
+                               Number<c_m10_n10_m11_n11_thread_tensor_lengths[I2]>{},
+                               Number<c_m10_n10_m11_n11_thread_tensor_lengths[I3]>{}));
+
+            const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
+                blockwise_gemm.CalculateCM0M1N0N1ThreadOriginIndex(get_thread_local_1d_id());
+
+            ThreadwiseDynamicTensorSliceTransfer_v1r3<
+                FloatAcc,
+                FloatC,
+                decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
+                decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
+                Sequence<1,
+                         c_m10_n10_m11_n11_thread_tensor_lengths[I0],
+                         c_m10_n10_m11_n11_thread_tensor_lengths[I1],
+                         1,
+                         c_m10_n10_m11_n11_thread_tensor_lengths[I2],
+                         c_m10_n10_m11_n11_thread_tensor_lengths[I3]>,
+                Sequence<3, 4, 5, 0, 1, 2>, // TODO: CThreadTransferSrcDstAccessOrder
+                5,                          // TODO: CThreadTransferSrcDstVectorDim
+                CThreadTransferDstScalarPerVector,
+                CGlobalMemoryDataOperation,
+                1,
+                true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
+                      make_multi_index(__builtin_amdgcn_readfirstlane(block_work_idx[I0]),
+                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
+                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
+                                       __builtin_amdgcn_readfirstlane(block_work_idx[I1]),
+                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
+                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
+                .Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
+                     make_tuple(I0, I0, I0, I0, I0, I0),
+                     c_thread_buf,
+                     c_m0_m10_m11_n0_n10_n11_grid_desc,
+                     c_grid_buf,
+                     CGridIteratorHacks{});
+#endif
        }
    }
@@ -537,8 +563,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                               FloatC* __restrict__ p_c_grid,
                               const AKMGridDesc& a_k_m_grid_desc,
                               const BKNGridDesc& b_k_n_grid_desc,
-                               const CM0M1N0N1GridDesc& c_m0_m1_n0_n1_grid_desc,
+                               const CM10M11N10N11GridDesc& c_m10_n10_m11_n11_grid_desc,
                               const CBlockClusterDesc& c_block_cluster_desc,
+                               const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
                               integral_constant<bool, HasMainKBlockLoop>,
                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
@@ -551,8 +578,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
            p_c_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
-            c_m0_m1_n0_n1_grid_desc,
+            c_m10_n10_m11_n11_grid_desc,
            c_block_cluster_desc,
+            c_m0_m10_m11_n0_n10_n11_grid_desc,
            p_shared_block,
            integral_constant<bool, HasMainKBlockLoop>{},
            integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -28,7 +28,7 @@
 #endif

 // launch bounds
-#define CK_USE_LAUNCH_BOUNDS 1
+#define CK_USE_LAUNCH_BOUNDS 0

 #ifdef CK_USE_LAUNCH_BOUNDS
 #define CK_MAX_THREAD_PER_BLOCK 256
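This flips CK_USE_LAUNCH_BOUNDS from 1 to 0, so kernels that test the flag with `#if CK_USE_LAUNCH_BOUNDS` no longer carry a `__launch_bounds__` annotation. The sketch below illustrates what the flag controls; it is an assumption for illustration, not the CK source, and the value of CK_MIN_BLOCK_PER_CU is made up.

```cpp
// Illustrative sketch (assumption, not CK code): __launch_bounds__ caps the
// block size and hints a blocks-per-CU target to the register allocator;
// setting the flag to 0 simply drops the annotation.
#include <hip/hip_runtime.h>

#define CK_USE_LAUNCH_BOUNDS 0
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 2 // assumed value, for illustration only

__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
dummy_kernel(float* p, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] = 2.0f * p[i];
}
```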
@@ -499,6 +499,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
    constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};

+#if 0
    // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_grid
    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
@@ -509,6 +510,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{}));
+#else
+    constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 1, 0, 0>{},
+                              Sequence<0, 0, 1, 0, 0>{},
+                              Sequence<0, 0, 1, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 2, 0, 0>{},
+                              Sequence<0, 0, 2, 0, 0>{},
+                              Sequence<0, 0, 2, 0, 0>{}));
+#endif

    for(index_t i = 0; i < 5; ++i)
    {
@@ -553,7 +569,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
        GemmCThreadTransferDstScalarPerVector_GemmN1,
        decltype(wei_gemmk_gemmm_grid_iterator_hacks),
        decltype(in_gemmk_gemmn_grid_iterator_hacks),
+#if 0
        decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks),
+#else
+        decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
+#endif
        decltype(wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks),
        decltype(in_gemmk_gemmn_grid_move_slice_window_iterator_hacks)>(
            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
@@ -566,7 +586,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
            out_gemmm_gemmn_grid_desc,
            wei_gemmk_gemmm_grid_iterator_hacks,
            in_gemmk_gemmn_grid_iterator_hacks,
+#if 0
            out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks,
+#else
+            out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
+#endif
            wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks,
            in_gemmk_gemmn_grid_move_slice_window_iterator_hacks,
            nrepeat);