Commit 02d23347 authored by Chao Liu's avatar Chao Liu
Browse files

overhauling fwd-v4r4

parent 318db82b
...@@ -13,10 +13,9 @@ template <index_t BlockSize, ...@@ -13,10 +13,9 @@ template <index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc, typename AKMGridDesc,
typename BGlobalDesc, typename BKNGridDesc,
typename CGlobalDesc, typename CMNGridDesc,
typename CBlockClusterDesc,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
...@@ -46,23 +45,22 @@ template <index_t BlockSize, ...@@ -46,23 +45,22 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks, typename AGridIteratorHacks,
typename BGlobalIteratorHacks, typename BGridIteratorHacks,
typename CGlobalIteratorHacks, typename CGridIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_global, const FloatAB* p_b_grid,
FloatC* p_c_global, FloatC* p_c_grid,
const AGlobalDesc& a_k_m_global_desc, const AKMGridDesc& a_k_m_grid_desc,
const BGlobalDesc& b_k_n_global_desc, const BKNGridDesc& b_k_n_grid_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc, const CMNGridDesc& c_m_n_grid_desc,
const CBlockClusterDesc& c_block_cluster_desc, AGridIteratorHacks,
AGlobalIteratorHacks, BGridIteratorHacks,
BGlobalIteratorHacks, CGridIteratorHacks,
CGlobalIteratorHacks, AGridMoveSliceWindowIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks,
index_t nrepeat) index_t nrepeat)
{ {
...@@ -71,23 +69,41 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -71,23 +69,41 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
const auto M = a_k_m_global_desc.GetLength(I1); const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1); const auto N = b_k_n_grid_desc.GetLength(I1);
const auto K = a_k_m_global_desc.GetLength(I0); const auto K = a_k_m_grid_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{ {
throw std::runtime_error("wrong! GEMM size no divisible"); throw std::runtime_error("wrong! GEMM size no divisible");
} }
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{}; const auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{}; const auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0)) if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
{ {
throw std::runtime_error("wrong! GEMM size no divisible"); throw std::runtime_error("wrong! GEMM size no divisible");
} }
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto c_m0_m1_n0_n1_grid_desc =
transform_dynamic_tensor_descriptor(c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M1)),
make_unmerge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
using CM0M1N0N1GridDesc = decltype(c_m0_m1_n0_n1_grid_desc);
// out_gemm_block_cluster_desc
const auto c_block_cluster_desc =
make_cluster_descriptor_v2(make_tuple(M / Number<MPerBlock>{}, N / Number<NPerBlock>{}));
using CBlockClusterDesc = decltype(c_block_cluster_desc);
// GEMM // GEMM
using gridwise_gemm = using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2<BlockSize, GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2<BlockSize,
...@@ -95,9 +111,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -95,9 +111,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAcc, FloatAcc,
FloatC, FloatC,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
AGlobalDesc, AKMGridDesc,
BGlobalDesc, BKNGridDesc,
CGlobalDesc, CM0M1N0N1GridDesc,
CBlockClusterDesc, CBlockClusterDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
...@@ -128,11 +144,11 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -128,11 +144,11 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks, AGridIteratorHacks,
BGlobalIteratorHacks, BGridIteratorHacks,
CGlobalIteratorHacks, CGridIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock); const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
...@@ -149,9 +165,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -149,9 +165,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
true, true,
true>; true>;
...@@ -162,12 +178,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -162,12 +178,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
p_a_global, p_a_grid,
p_b_global, p_b_grid,
p_c_global, p_c_grid,
a_k_m_global_desc, a_k_m_grid_desc,
b_k_n_global_desc, b_k_n_grid_desc,
c_m0_m1_n0_n1_global_desc, c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc); c_block_cluster_desc);
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
...@@ -176,9 +192,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -176,9 +192,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
true, true,
false>; false>;
...@@ -189,12 +205,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -189,12 +205,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
p_a_global, p_a_grid,
p_b_global, p_b_grid,
p_c_global, p_c_grid,
a_k_m_global_desc, a_k_m_grid_desc,
b_k_n_global_desc, b_k_n_grid_desc,
c_m0_m1_n0_n1_global_desc, c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc); c_block_cluster_desc);
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
...@@ -203,9 +219,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -203,9 +219,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
false, false,
true>; true>;
...@@ -216,12 +232,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -216,12 +232,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
p_a_global, p_a_grid,
p_b_global, p_b_grid,
p_c_global, p_c_grid,
a_k_m_global_desc, a_k_m_grid_desc,
b_k_n_global_desc, b_k_n_grid_desc,
c_m0_m1_n0_n1_global_desc, c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc); c_block_cluster_desc);
} }
else else
...@@ -230,9 +246,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -230,9 +246,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
false, false,
false>; false>;
...@@ -243,25 +259,25 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -243,25 +259,25 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
p_a_global, p_a_grid,
p_b_global, p_b_grid,
p_c_global, p_c_grid,
a_k_m_global_desc, a_k_m_grid_desc,
b_k_n_global_desc, b_k_n_grid_desc,
c_m0_m1_n0_n1_global_desc, c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc); c_block_cluster_desc);
} }
return ave_time; return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc)); DeviceMem a_k_m_grid_desc_device_buf(sizeof(AKMGridDesc));
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc)); DeviceMem b_k_n_grid_desc_device_buf(sizeof(BKNGridDesc));
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc)); DeviceMem c_m0_m1_n0_n1_grid_desc_device_buf(sizeof(CM0M1N0N1GridDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc)); DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc); a_k_m_grid_desc_device_buf.ToDevice(&a_k_m_grid_desc);
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc); b_k_n_grid_desc_device_buf.ToDevice(&b_k_n_grid_desc);
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc); c_m0_m1_n0_n1_grid_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_grid_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc); c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0; float ave_time = 0;
...@@ -272,9 +288,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -272,9 +288,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
true, true,
true>; true>;
...@@ -286,12 +302,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -286,12 +302,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
p_a_global, p_a_grid,
p_b_global, p_b_grid,
p_c_global, p_c_grid,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
...@@ -300,9 +316,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -300,9 +316,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
true, true,
false>; false>;
...@@ -314,12 +330,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -314,12 +330,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
p_a_global, p_a_grid,
p_b_global, p_b_grid,
p_c_global, p_c_grid,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
...@@ -328,9 +344,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -328,9 +344,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
false, false,
true>; true>;
...@@ -342,12 +358,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -342,12 +358,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
p_a_global, p_a_grid,
p_b_global, p_b_grid,
p_c_global, p_c_grid,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
} }
else else
...@@ -356,9 +372,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -356,9 +372,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
false, false,
false>; false>;
...@@ -370,12 +386,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global, ...@@ -370,12 +386,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
p_a_global, p_a_grid,
p_b_global, p_b_grid,
p_c_global, p_c_grid,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), (void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
} }
......
...@@ -482,29 +482,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -482,29 +482,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
const auto in_gemmk_gemmn_grid_desc = descs[I1]; const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2]; const auto out_gemmm_gemmn_grid_desc = descs[I2];
const auto GemmM = out_gemmm_gemmn_grid_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_grid_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_grid_desc.GetLength(I0);
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto GemmM0 = GemmM / Number<GemmM1>{};
const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor // hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor
constexpr auto wei_gemmk_gemmm_grid_iterator_hacks = constexpr auto wei_gemmk_gemmm_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
...@@ -543,8 +520,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -543,8 +520,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_grid_desc), decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc), decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
decltype(out_gemm_block_cluster_desc),
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
...@@ -587,8 +563,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -587,8 +563,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gemmk_gemmm_grid_desc, wei_gemmk_gemmm_grid_desc,
in_gemmk_gemmn_grid_desc, in_gemmk_gemmn_grid_desc,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc, out_gemmm_gemmn_grid_desc,
out_gemm_block_cluster_desc,
wei_gemmk_gemmm_grid_iterator_hacks, wei_gemmk_gemmm_grid_iterator_hacks,
in_gemmk_gemmn_grid_iterator_hacks, in_gemmk_gemmn_grid_iterator_hacks,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks, out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment