Commit 02d23347 authored by Chao Liu's avatar Chao Liu
Browse files

overhauling fwd-v4r4

parent 318db82b
......@@ -13,10 +13,9 @@ template <index_t BlockSize,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
......@@ -46,23 +45,22 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
const FloatAB* p_b_global,
FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
......@@ -71,23 +69,41 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto K = a_k_m_grid_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
const auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
const auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto c_m0_m1_n0_n1_grid_desc =
transform_dynamic_tensor_descriptor(c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M1)),
make_unmerge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
using CM0M1N0N1GridDesc = decltype(c_m0_m1_n0_n1_grid_desc);
// out_gemm_block_cluster_desc
const auto c_block_cluster_desc =
make_cluster_descriptor_v2(make_tuple(M / Number<MPerBlock>{}, N / Number<NPerBlock>{}));
using CBlockClusterDesc = decltype(c_block_cluster_desc);
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2<BlockSize,
......@@ -95,9 +111,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGlobalDesc,
BGlobalDesc,
CGlobalDesc,
AKMGridDesc,
BKNGridDesc,
CM0M1N0N1GridDesc,
CBlockClusterDesc,
MPerBlock,
NPerBlock,
......@@ -128,11 +144,11 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>;
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
......@@ -149,9 +165,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
......@@ -162,12 +178,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
......@@ -176,9 +192,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
......@@ -189,12 +205,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
......@@ -203,9 +219,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
......@@ -216,12 +232,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
}
else
......@@ -230,9 +246,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
......@@ -243,25 +259,25 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc);
}
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
DeviceMem a_k_m_grid_desc_device_buf(sizeof(AKMGridDesc));
DeviceMem b_k_n_grid_desc_device_buf(sizeof(BKNGridDesc));
DeviceMem c_m0_m1_n0_n1_grid_desc_device_buf(sizeof(CM0M1N0N1GridDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
a_k_m_grid_desc_device_buf.ToDevice(&a_k_m_grid_desc);
b_k_n_grid_desc_device_buf.ToDevice(&b_k_n_grid_desc);
c_m0_m1_n0_n1_grid_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_grid_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0;
......@@ -272,9 +288,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
......@@ -286,12 +302,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
......@@ -300,9 +316,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
......@@ -314,12 +330,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
......@@ -328,9 +344,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
......@@ -342,12 +358,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else
......@@ -356,9 +372,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
......@@ -370,12 +386,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k_m_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_grid_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
......
......@@ -482,29 +482,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
const auto GemmM = out_gemmm_gemmn_grid_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_grid_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_grid_desc.GetLength(I0);
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto GemmM0 = GemmM / Number<GemmM1>{};
const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor
constexpr auto wei_gemmk_gemmm_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
......@@ -543,8 +520,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc),
decltype(out_gemm_block_cluster_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
......@@ -587,8 +563,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gemmk_gemmm_grid_desc,
in_gemmk_gemmn_grid_desc,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc,
out_gemm_block_cluster_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk_gemmm_grid_iterator_hacks,
in_gemmk_gemmn_grid_iterator_hacks,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment