Commit b3a012bc authored by root

thread mapping

parent 2662f8e5
@@ -111,24 +111,28 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
 
-        const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
+        const auto in_gemmk_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
             in_n_c_y_ho_x_wo_global_desc,
             make_tuple(make_merge_transform(make_tuple(C, Y, X)),
-                       make_merge_transform(make_tuple(N, Ho, Wo))),
-            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
+                       make_pass_through_transform(N),
+                       make_pass_through_transform(Ho),
+                       make_pass_through_transform(Wo)),
+            make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
 
         // output tensor
-        const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
+        const auto out_gemmm_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
+            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)),
             make_tuple(make_pass_through_transform(K),
-                       make_merge_transform(make_tuple(N, Ho * Wo))),
-            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
+                       make_pass_through_transform(N),
+                       make_pass_through_transform(Ho),
+                       make_pass_through_transform(Wo)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
 
-        const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
-        const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
-        const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
+        const auto GemmM = K;
+        const auto GemmN = N * Ho * Wo;
+        const auto GemmK = C * Y * X;
 
         if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
              GemmK % GemmKPerBlock == 0))
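For orientation: the reworked input transform keeps N, Ho and Wo as separate coordinates and only merges (C, Y, X) into GemmK, which is why GemmM/GemmN/GemmK can now be written directly as K, N * Ho * Wo and C * Y * X. A minimal host-side sketch of the merged GemmK coordinate (sizes are illustrative, not taken from this commit):

#include <cstdio>

int main()
{
    const int C = 8, Y = 3, X = 3; // illustrative convolution sizes

    // gemmk walks the merged (c, y, x) space, mirroring
    // make_merge_transform(make_tuple(C, Y, X)) applied to source dims {1, 2, 4}
    for(int gemmk = 0; gemmk < C * Y * X; ++gemmk)
    {
        const int c = gemmk / (Y * X);
        const int y = (gemmk / X) % Y;
        const int x = gemmk % X;

        if(gemmk < 4) // print a few samples
            std::printf("gemmk=%d -> (c=%d, y=%d, x=%d)\n", gemmk, c, y, x);
    }
    return 0;
}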
@@ -136,20 +140,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             throw std::runtime_error("wrong! GEMM size not divisible");
         }
 
-        constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{};
-        constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{};
-
-        const auto GemmM0 = GemmM / GemmM1;
-        const auto GemmN0 = GemmN / GemmN1;
-
-        const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
-            transform_dynamic_tensor_descriptor(
-                out_gemmm_gemmn_global_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
-                           make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-
         // hack to control index calculation when iterating over a_k_m_global tensor
         constexpr auto a_k_m_global_iterator_hacks =
             make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
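The deleted block above split each output GEMM dimension for the old m0/m1/n0/n1 thread mapping. A small sketch of that unmerge arithmetic, using the pre-commit tuning values (GemmMPerThread = GemmMLevel0Cluster = GemmMLevel1Cluster = 2), showing the round trip m <-> (m0, m1):

#include <cassert>

int main()
{
    const int GemmM  = 128;       // hypothetical problem size
    const int GemmM1 = 2 * 2 * 2; // GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster (old values)
    const int GemmM0 = GemmM / GemmM1;

    // unmerge: m -> (m0, m1); merging back must reproduce m
    for(int m = 0; m < GemmM; ++m)
        assert((m / GemmM1) * GemmM1 + m % GemmM1 == m);

    return GemmM0 == 16 ? 0 : 1;
}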
@@ -157,28 +147,32 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
         constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
 
-        // hack to control index calculation when iterating over b_k_n_global tensor
         constexpr auto b_k_n_global_iterator_hacks =
-            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
-                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
-                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
-                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
+            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
+                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
 
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
-            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
+            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
 
-        // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
+        // hack to control index calculation when iterating over c_k_n_h_w_global tensor
         // hack for NKHW format
-        constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
+        constexpr auto c_k_n_h_w_global_tensor_iterator_hacks =
             make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                   Sequence<0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 1, 0, 0>{},
-                                  Sequence<0, 0, 1, 0, 0>{}),
+                                  Sequence<0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0>{}),
                        make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                   Sequence<0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 2, 0, 0>{},
-                                  Sequence<0, 0, 2, 0, 0>{}));
+                                  Sequence<0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0>{}));
 
+#if 1
         // GEMM
         using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v2<
             BlockSize,
@@ -186,8 +180,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             AccFloat,
             InMemoryDataOperation::Set,
             decltype(wei_gemmk_gemmm_global_desc),
-            decltype(in_gemmk_gemmn_global_desc),
-            decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+            decltype(in_gemmk_n_ho_wo_global_desc),
+            decltype(out_gemmm_n_ho_wo_global_desc),
             GemmMPerBlock,
             GemmNPerBlock,
             GemmKPerBlock,
@@ -208,19 +202,19 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             false, // don't move back src coordinate after threadwise copy
             GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
             GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
-            Sequence<0, 1>,
-            Sequence<0, 1>,
-            1,
+            Sequence<3, 2, 1, 0>,
+            Sequence<3, 2, 1, 0>,
+            3,
             GemmBBlockTransferSrcScalarPerVector_GemmN,
             GemmBBlockTransferDstScalarPerVector_GemmN,
             false, // don't move back src coordinate after threadwise copy, which will be fused with
                    // MoveSrcSliceWindow() to save addr computation
-            Sequence<2, 3, 0, 1>,
+            Sequence<3, 2, 1, 0>,
             3,
             GemmCThreadTransferDstScalarPerVector_GemmN1,
             decltype(a_k_m_global_iterator_hacks),
             decltype(b_k_n_global_iterator_hacks),
-            decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
+            decltype(c_k_n_h_w_global_tensor_iterator_hacks),
             decltype(a_k_m_global_move_slice_window_iterator_hack),
             decltype(b_k_n_global_move_slice_window_iterator_hack)>;
@@ -230,7 +224,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
         const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
 
-#if 1 // pass tensor descriptors by their reference
         index_t nrepeat = 100;
 
         for(index_t i = 0; i < 5; ++i)
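The branches below select among four compile-time variants of the same kernel. A minimal analogue of that dispatch; the has_main_k_block_loop condition is an assumption for illustration, only the double-tail test is visible in this diff:

#include <cstdio>

template <bool MainLoop, bool DoubleTail>
void run_variant() { std::printf("main=%d double_tail=%d\n", MainLoop, DoubleTail); }

int main()
{
    const int GemmK = 72, GemmKPerBlock = 4;
    const bool has_main_k_block_loop        = (GemmK / GemmKPerBlock) > 2;      // assumed
    const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; // as in the driver

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
        run_variant<true, true>();
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
        run_variant<true, false>();
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
        run_variant<false, true>();
    else
        run_variant<false, false>();
    return 0;
}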
@@ -248,10 +241,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                     run_gridwise_operation<gridwise_gemm,
                                            decltype(wei_gemmk_gemmm_global_desc),
                                            const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc),
+                                           decltype(in_gemmk_n_ho_wo_global_desc),
                                            const Float*,
-                                           decltype(
-                                               out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                           decltype(out_gemmm_n_ho_wo_global_desc),
                                            Float*,
                                            integral_constant<bool, true>,
                                            integral_constant<bool, true>>;
@@ -263,9 +255,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                           0,
                           wei_gemmk_gemmm_global_desc,
                           p_wei_global,
-                          in_gemmk_gemmn_global_desc,
+                          in_gemmk_n_ho_wo_global_desc,
                           p_in_global,
-                          out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                          out_gemmm_n_ho_wo_global_desc,
                           p_out_global,
                           integral_constant<bool, true>{},
                           integral_constant<bool, true>{});
@@ -276,10 +268,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                     run_gridwise_operation<gridwise_gemm,
                                            decltype(wei_gemmk_gemmm_global_desc),
                                            const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc),
+                                           decltype(in_gemmk_n_ho_wo_global_desc),
                                            const Float*,
-                                           decltype(
-                                               out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                           decltype(out_gemmm_n_ho_wo_global_desc),
                                            Float*,
                                            integral_constant<bool, true>,
                                            integral_constant<bool, false>>;
@@ -291,9 +282,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                           0,
                           wei_gemmk_gemmm_global_desc,
                           p_wei_global,
-                          in_gemmk_gemmn_global_desc,
+                          in_gemmk_n_ho_wo_global_desc,
                           p_in_global,
-                          out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                          out_gemmm_n_ho_wo_global_desc,
                           p_out_global,
                           integral_constant<bool, true>{},
                           integral_constant<bool, false>{});
@@ -304,10 +295,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                     run_gridwise_operation<gridwise_gemm,
                                            decltype(wei_gemmk_gemmm_global_desc),
                                            const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc),
+                                           decltype(in_gemmk_n_ho_wo_global_desc),
                                            const Float*,
-                                           decltype(
-                                               out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                           decltype(out_gemmm_n_ho_wo_global_desc),
                                            Float*,
                                            integral_constant<bool, false>,
                                            integral_constant<bool, true>>;
@@ -319,9 +309,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                           0,
                           wei_gemmk_gemmm_global_desc,
                           p_wei_global,
-                          in_gemmk_gemmn_global_desc,
+                          in_gemmk_n_ho_wo_global_desc,
                           p_in_global,
-                          out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                          out_gemmm_n_ho_wo_global_desc,
                           p_out_global,
                           integral_constant<bool, false>{},
                           integral_constant<bool, true>{});
@@ -332,10 +322,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                     run_gridwise_operation<gridwise_gemm,
                                            decltype(wei_gemmk_gemmm_global_desc),
                                            const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc),
+                                           decltype(in_gemmk_n_ho_wo_global_desc),
                                            const Float*,
-                                           decltype(
-                                               out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                           decltype(out_gemmm_n_ho_wo_global_desc),
                                            Float*,
                                            integral_constant<bool, false>,
                                            integral_constant<bool, false>>;
@@ -347,323 +336,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                           0,
                           wei_gemmk_gemmm_global_desc,
                           p_wei_global,
-                          in_gemmk_gemmn_global_desc,
-                          p_in_global,
-                          out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
-                          p_out_global,
-                          integral_constant<bool, false>{},
-                          integral_constant<bool, false>{});
-            }
-        }
-
-        timer.End();
-
-        float ave_time = timer.GetElapsedTime() / nrepeat;
-
-        float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
-                                                        wei_k_c_y_x_global_desc,
-                                                        out_n_k_ho_wo_global_desc) /
-                     (std::size_t(1000) * 1000 * 1000) / ave_time;
-
-        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
-                  << std::endl;
-    }
-#elif 1 // pass tensor descriptors by their pointers
-    using ADesc = decltype(wei_gemmk_gemmm_global_desc);
-    using BDesc = decltype(in_gemmk_gemmn_global_desc);
-    using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
-
-    DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
-    DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
-    DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
-
-    wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
-    in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
-    out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
-        &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
-
-    index_t nrepeat = 100;
-
-    for(index_t i = 0; i < 5; ++i)
-    {
-        std::cout << "Start running " << nrepeat << " times..." << std::endl;
-
-        KernelTimer timer;
-        timer.Start();
-
-        for(index_t j = 0; j < nrepeat; ++j)
-        {
-            if(has_main_k_block_loop && has_double_tail_k_block_loop)
-            {
-                const auto kernel =
-                    run_gridwise_operation<gridwise_gemm,
-                                           decltype(wei_gemmk_gemmm_global_desc)*,
-                                           const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc)*,
-                                           const Float*,
-                                           decltype(
-                                               out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                                           Float*,
-                                           integral_constant<bool, true>,
-                                           integral_constant<bool, true>>;
-
-                launch_kernel(kernel,
-                              dim3(GridSize),
-                              dim3(BlockSize),
-                              0,
-                              0,
-                              reinterpret_cast<const ADesc*>(
-                                  wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
-                              p_wei_global,
-                              reinterpret_cast<const BDesc*>(
-                                  in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
-                              p_in_global,
-                              reinterpret_cast<const CDesc*>(
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                      .GetDeviceBuffer()),
-                              p_out_global,
-                              integral_constant<bool, true>{},
-                              integral_constant<bool, true>{});
-            }
-            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
-            {
-                const auto kernel =
-                    run_gridwise_operation<gridwise_gemm,
-                                           decltype(wei_gemmk_gemmm_global_desc)*,
-                                           const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc)*,
-                                           const Float*,
-                                           decltype(
-                                               out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                                           Float*,
-                                           integral_constant<bool, true>,
-                                           integral_constant<bool, false>>;
-
-                launch_kernel(kernel,
-                              dim3(GridSize),
-                              dim3(BlockSize),
-                              0,
-                              0,
-                              reinterpret_cast<const ADesc*>(
-                                  wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
-                              p_wei_global,
-                              reinterpret_cast<const BDesc*>(
-                                  in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
-                              p_in_global,
-                              reinterpret_cast<const CDesc*>(
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                      .GetDeviceBuffer()),
-                              p_out_global,
-                              integral_constant<bool, true>{},
-                              integral_constant<bool, false>{});
-            }
-            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
-            {
-                const auto kernel =
-                    run_gridwise_operation<gridwise_gemm,
-                                           decltype(wei_gemmk_gemmm_global_desc)*,
-                                           const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc)*,
-                                           const Float*,
-                                           decltype(
-                                               out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                                           Float*,
-                                           integral_constant<bool, false>,
-                                           integral_constant<bool, true>>;
-
-                launch_kernel(kernel,
-                              dim3(GridSize),
-                              dim3(BlockSize),
-                              0,
-                              0,
-                              reinterpret_cast<const ADesc*>(
-                                  wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
-                              p_wei_global,
-                              reinterpret_cast<const BDesc*>(
-                                  in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
-                              p_in_global,
-                              reinterpret_cast<const CDesc*>(
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                      .GetDeviceBuffer()),
-                              p_out_global,
-                              integral_constant<bool, false>{},
-                              integral_constant<bool, true>{});
-            }
-            else
-            {
-                const auto kernel =
-                    run_gridwise_operation<gridwise_gemm,
-                                           decltype(wei_gemmk_gemmm_global_desc)*,
-                                           const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc)*,
-                                           const Float*,
-                                           decltype(
-                                               out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                                           Float*,
-                                           integral_constant<bool, false>,
-                                           integral_constant<bool, false>>;
-
-                launch_kernel(kernel,
-                              dim3(GridSize),
-                              dim3(BlockSize),
-                              0,
-                              0,
-                              reinterpret_cast<const ADesc*>(
-                                  wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
-                              p_wei_global,
-                              reinterpret_cast<const BDesc*>(
-                                  in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
-                              p_in_global,
-                              reinterpret_cast<const CDesc*>(
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                      .GetDeviceBuffer()),
-                              p_out_global,
-                              integral_constant<bool, false>{},
-                              integral_constant<bool, false>{});
-            }
-        }
-
-        timer.End();
-
-        float ave_time = timer.GetElapsedTime() / nrepeat;
-
-        float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
-                                                        wei_k_c_y_x_global_desc,
-                                                        out_n_k_ho_wo_global_desc) /
-                     (std::size_t(1000) * 1000 * 1000) / ave_time;
-
-        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
-                  << std::endl;
-    }
-#elif 1 // pass tensor descriptor by void*
-    using ADesc = decltype(wei_gemmk_gemmm_global_desc);
-    using BDesc = decltype(in_gemmk_gemmn_global_desc);
-    using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
-
-    DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
-    DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
-    DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
-
-    wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
-    in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
-    out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
-        &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
-
-    index_t nrepeat = 100;
-
-    for(index_t i = 0; i < 5; ++i)
-    {
-        std::cout << "Start running " << nrepeat << " times..." << std::endl;
-
-        KernelTimer timer;
-        timer.Start();
-
-        for(index_t j = 0; j < nrepeat; ++j)
-        {
-            if(has_main_k_block_loop && has_double_tail_k_block_loop)
-            {
-                const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                           const void*,
-                                                           const Float*,
-                                                           const void*,
-                                                           const Float*,
-                                                           const void*,
-                                                           Float*,
-                                                           integral_constant<bool, true>,
-                                                           integral_constant<bool, true>>;
-
-                launch_kernel(kernel,
-                              dim3(GridSize),
-                              dim3(BlockSize),
-                              0,
-                              0,
-                              wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                              p_wei_global,
-                              in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                              p_in_global,
-                              out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                  .GetDeviceBuffer(),
-                              p_out_global,
-                              integral_constant<bool, true>{},
-                              integral_constant<bool, true>{});
-            }
-            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
-            {
-                const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                           const void*,
-                                                           const Float*,
-                                                           const void*,
-                                                           const Float*,
-                                                           const void*,
-                                                           Float*,
-                                                           integral_constant<bool, true>,
-                                                           integral_constant<bool, false>>;
-
-                launch_kernel(kernel,
-                              dim3(GridSize),
-                              dim3(BlockSize),
-                              0,
-                              0,
-                              wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                              p_wei_global,
-                              in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                              p_in_global,
-                              out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                  .GetDeviceBuffer(),
-                              p_out_global,
-                              integral_constant<bool, true>{},
-                              integral_constant<bool, false>{});
-            }
-            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
-            {
-                const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                           const void*,
-                                                           const Float*,
-                                                           const void*,
-                                                           const Float*,
-                                                           const void*,
-                                                           Float*,
-                                                           integral_constant<bool, false>,
-                                                           integral_constant<bool, true>>;
-
-                launch_kernel(kernel,
-                              dim3(GridSize),
-                              dim3(BlockSize),
-                              0,
-                              0,
-                              wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                              p_wei_global,
-                              in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                              p_in_global,
-                              out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                  .GetDeviceBuffer(),
-                              p_out_global,
-                              integral_constant<bool, false>{},
-                              integral_constant<bool, true>{});
-            }
-            else
-            {
-                const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                           const void*,
-                                                           const Float*,
-                                                           const void*,
-                                                           const Float*,
-                                                           const void*,
-                                                           Float*,
-                                                           integral_constant<bool, false>,
-                                                           integral_constant<bool, false>>;
-
-                launch_kernel(kernel,
-                              dim3(GridSize),
-                              dim3(BlockSize),
-                              0,
-                              0,
-                              wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                              p_wei_global,
-                              in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
+                          in_gemmk_n_ho_wo_global_desc,
                           p_in_global,
-                          out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                              .GetDeviceBuffer(),
+                          out_gemmm_n_ho_wo_global_desc,
                           p_out_global,
                           integral_constant<bool, false>{},
                           integral_constant<bool, false>{});
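A unit check on the timing code deleted above: calculate_convolution_flops divided by 1e9 yields GFLOP, and GFLOP divided by a time in milliseconds equals TFLOP/s, so the printed figure is dimensionally consistent. A sketch with hypothetical numbers:

#include <cstdio>

int main()
{
    const double flops  = 2.0 * 128 * 16384 * 72; // 2*M*N*K for one implicit GEMM (illustrative)
    const double ave_ms = 0.12;                   // hypothetical average kernel time

    const double tflops = flops / (1000.0 * 1000 * 1000) / ave_ms; // GFLOP / ms == TFLOP/s
    std::printf("%.3f TFlop/s\n", tflops);
    return 0;
}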
@@ -68,15 +68,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
+        constexpr auto b_cyx_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}), max_lds_align);
 
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
             math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
 
         constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align);
 
         return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
     }
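The new B tile is KPerBlock x 1 x 8 x 8, so the LDS budget follows directly. A sketch with hypothetical template values; the alignment padding from integer_least_multiple is ignored here:

#include <cstdio>

int main()
{
    const int KPerBlock = 4, MPerBlock = 16; // assumed tile sizes

    const int a_elems = KPerBlock * MPerBlock; // assumed a_k_m tile
    const int b_elems = KPerBlock * 1 * 8 * 8; // new b_cyx_n_h_w tile
    const int bytes   = 2 * (a_elems + b_elems) * static_cast<int>(sizeof(float)); // x2: double buffer

    std::printf("LDS bytes (no alignment padding): %d\n", bytes);
    return 0;
}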
@@ -84,9 +84,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
+                        const BGlobalDesc& b_cyx_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                        const CGlobalDesc& c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         Float* __restrict__ p_shared_block,
                         integral_constant<bool, HasMainKBlockLoop>,
@@ -97,7 +97,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
-        const auto N = b_k_n_global_desc.GetLength(I1);
+        const auto N = b_cyx_n_h_w_global_desc.GetLength(I1);
 
         // divide block work by [M, N]
 #if 0
@@ -118,7 +118,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
 #endif
 
         const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
-        const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
+
+        const index_t h_block_data_on_global = n_block_work_id * 8;
+        const index_t w_block_data_on_global = n_block_work_id * 8;
 
         // lds max alignment
         constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
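As committed, both tile offsets above derive from the same n_block_work_id, so successive blocks cover 8x8 (ho, wo) tiles along the diagonal of the output grid; this looks like a work-in-progress placeholder rather than a full 2-D decomposition. A sketch of the mapping as written:

#include <cstdio>

int main()
{
    for(int n_block_work_id = 0; n_block_work_id < 3; ++n_block_work_id)
    {
        const int h0 = n_block_work_id * 8; // h_block_data_on_global
        const int w0 = n_block_work_id * 8; // w_block_data_on_global
        std::printf("block %d: ho in [%d, %d), wo in [%d, %d)\n",
                    n_block_work_id, h0, h0 + 8, w0, w0 + 8);
    }
    return 0;
}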
@@ -133,8 +135,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
+        constexpr auto b_cyx_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}), max_lds_align);
 
         // A matrix blockwise copy
         auto a_blockwise_copy =
@@ -166,33 +168,52 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                 make_multi_index(0, 0));
 
         // B matrix blockwise copy
-        auto b_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
-                                                   InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, NPerBlock>,
-                                                   BBlockTransferThreadSliceLengths_K_N,
-                                                   BBlockTransferThreadClusterLengths_K_N,
-                                                   BBlockTransferThreadClusterArrangeOrder,
-                                                   Float,
-                                                   Float,
-                                                   decltype(b_k_n_global_desc),
-                                                   decltype(b_k_n_block_desc),
-                                                   BBlockTransferSrcAccessOrder,
-                                                   Sequence<0, 1>,
-                                                   BBlockTransferSrcVectorDim,
-                                                   1,
-                                                   BBlockTransferSrcScalarPerVector,
-                                                   BBlockTransferDstScalarPerVector_N,
-                                                   AddressSpace::Global,
-                                                   AddressSpace::Lds,
-                                                   1,
-                                                   1,
-                                                   BThreadTransferSrcResetCoordinateAfterRun,
-                                                   true>(
-                b_k_n_global_desc,
-                make_multi_index(0, n_block_data_on_global),
-                b_k_n_block_desc,
-                make_multi_index(0, 0));
+        auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<
+            BlockSize,
+            InMemoryDataOperation::Set,
+            Sequence<KPerBlock, 1, 8, 8>,      // BlockSliceLengths
+            Sequence<KPerBlock, 1, 1, 1>,      // ThreadSliceLengths_K_N
+            Sequence<1, 1, 8, 8>,              // ThreadClusterLengths_K_N
+            Sequence<3, 2, 0, 1>,              // ThreadClusterArrangeOrder
+            Float,
+            Float,
+            decltype(b_cyx_n_h_w_global_desc), // SrcDesc
+            decltype(b_cyx_n_h_w_block_desc),  // DstDesc
+            Sequence<3, 2, 0, 1>,              // SrcDimAccessOrder
+            Sequence<3, 2, 0, 1>,              // DstDimAccessOrder
+            3,                                 // SrcVectorDim
+            3,                                 // DstVectorDim
+            1,                                 // SrcScalarPerVector
+            1,                                 // DstScalarPerVector
+            AddressSpace::Global,
+            AddressSpace::Lds,
+            1,
+            1,
+            BThreadTransferSrcResetCoordinateAfterRun,
+            true>(b_cyx_n_h_w_global_desc,
+                  make_multi_index(0, 0, h_block_data_on_global, w_block_data_on_global),
+                  b_cyx_n_h_w_block_desc,
+                  make_multi_index(0, 0, 0, 0));
+
+#if 0
+        constexpr auto b_cyx_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<KPerThread>{}, Number<NPerThread>{}));
+
+        using BThreadwiseTransfer =
+            ThreadwiseDynamicTensorSliceTransfer_v2<Float,
+                                                    Float,
+                                                    decltype(b_cyx_n_h_w_global_desc),
+                                                    decltype(b_cyx_n_h_w_thread_desc),
+                                                    Sequence<KPerThread, NPerThread>,
+                                                    BBlockTransferSrcAccessOrder,
+                                                    BBlockTransferSrcVectorDim,
+                                                    BBlockTransferSrcScalarPerVector,
+                                                    AddressSpace::Global,
+                                                    AddressSpace::Vgpr,
+                                                    InMemoryDataOperation::Set,
+                                                    1,
+                                                    true>;
+#endif
 
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
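The new B copy distributes a [KPerBlock, 1, 8, 8] tile over a [1, 1, 8, 8] thread cluster, i.e. 64 threads each moving a [KPerBlock, 1, 1, 1] sliver; the per-thread (h, w) assignment matches the tid/8, tid%8 split used later for the C write-out. A sketch of that assignment:

#include <cstdio>

int main()
{
    const int BlockSize = 64;
    for(int tid = 0; tid < BlockSize; ++tid)
    {
        const int h = tid / 8; // position along the 8-wide Ho slice of the tile
        const int w = tid % 8; // position along the 8-wide Wo slice
        if(tid < 3 || tid == BlockSize - 1)
            std::printf("tid=%2d -> (h=%d, w=%d)\n", tid, h, w);
    }
    return 0;
}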
@@ -201,70 +222,77 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
 
         // sanity check
-        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
-                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
-                      "wrong!");
+        //static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
+        //NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
+        //"wrong!");
 
-        constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
-        constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
+        // constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
+        // constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
 
         // c_thread_mtx definition: this is a mess
         // TODO:: more elegant way of defining c_thread_mtx
-        constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
+        constexpr auto c_k_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<MPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));
 
+#if 0
         const auto blockwise_gemm =
             BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
                                             decltype(a_k_m_block_desc),
-                                            decltype(b_k_n_block_desc),
-                                            decltype(c_m0m1_n0n1_thread_desc),
+                                            decltype(b_cyx_n_h_w_block_desc),
+                                            decltype(c_k_n_h_w_thread_desc),
                                             MPerThread,
                                             NPerThread,
                                             KPerThread,
                                             MLevel0Cluster,
                                             NLevel0Cluster,
                                             MLevel1Cluster,
                                             NLevel1Cluster,
-                                            MPerThread,
-                                            NPerThread>{};
+                                            1,
+                                            1>{};
+#endif
 
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
             math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
 
         constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align);
 
         Float* p_a_block_double = p_shared_block;
         Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
 
         // register allocation for output
-        AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
+        AccFloat p_c_thread[c_k_n_h_w_thread_desc.GetElementSpaceSize()];
+
+        for(index_t i = 0; i < c_k_n_h_w_thread_desc.GetElementSpaceSize(); i++)
+        {
+            p_c_thread[i] = 0;
+        }
 
         // zero out threadwise output
-        threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
+        // threadwise_matrix_set_zero_v2(c_k_n_h_w_thread_desc, p_c_thread);
 
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
 
         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
-        constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
+        constexpr auto b_cyx_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{};
 
         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
         constexpr auto a_k_m_global_move_slice_window_iterator_hack =
             AGlobalMoveSliceWindowIteratorHacks{};
-        constexpr auto b_k_n_global_move_slice_window_iterator_hack =
+        constexpr auto b_cyx_n_h_w_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};
 
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
 
             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
+            b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_double);
         }
 
         if constexpr(HasMainKBlockLoop)
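The loop entered below implements the usual LDS ping-pong: while the (currently disabled) blockwise GEMM consumes one buffer, the blockwise copies prefetch the next K-slice into the other. A scalar sketch of that schedule, with illustrative sizes:

#include <cstdio>

int main()
{
    const int K = 72, KPerBlock = 4; // illustrative
    int cur = 0;
    for(int k = 0; k < K; k += KPerBlock)
    {
        const int nxt = 1 - cur;
        if(k + KPerBlock < K)
            std::printf("k=%2d: prefetch next slice into buffer %d\n", k, nxt);
        std::printf("k=%2d: gemm on buffer %d (commented out in this commit)\n", k, cur);
        cur = nxt;
    }
    return 0;
}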
@@ -285,9 +313,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                 a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                     a_block_slice_copy_step,
                                                     a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+
+                // b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+                //                                     b_block_slice_copy_step,
+                //                                     b_cyx_n_h_w_global_move_slice_window_iterator_hack);
+
+                b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
                                                     b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                                                    b_cyx_n_h_w_global_move_slice_window_iterator_hack);
 
                 __syncthreads();
@@ -295,22 +328,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                 a_blockwise_copy.RunRead(
                     a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                    b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
+                // blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
+                b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_odd);
 
                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                     a_block_slice_copy_step,
                                                     a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+                b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
                                                     b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                                                    b_cyx_n_h_w_global_move_slice_window_iterator_hack);
 
                 __syncthreads();
@@ -318,14 +351,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                 a_blockwise_copy.RunRead(
                     a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                    b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
+                // blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
+                b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_even);
 
                 k_block_data_begin += 2 * KPerBlock;
             } while(k_block_data_begin < K - 2 * KPerBlock);
@@ -337,53 +370,49 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                 a_block_slice_copy_step,
                                                 a_k_m_global_move_slice_window_iterator_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+            b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
                                                 b_block_slice_copy_step,
-                                                b_k_n_global_move_slice_window_iterator_hack);
+                                                b_cyx_n_h_w_global_move_slice_window_iterator_hack);
 
             __syncthreads();
 
             // LDS double buffer: load last data from device mem
             a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
 
             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            // blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
 
             // LDS double buffer: store last data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+            b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_double + b_block_space_size);
 
             __syncthreads();
 
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double + a_block_space_size,
-                               p_b_block_double + b_block_space_size,
-                               p_c_thread);
+            // blockwise_gemm.Run(p_a_block_double + a_block_space_size,
+            //                    p_b_block_double + b_block_space_size,
+            //                    p_c_thread);
         }
         else // if has 1 iteration left
        {
             __syncthreads();
 
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            // blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
         }
 
+#if 1
         // output: register to global memory
         {
-            constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
-            constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
-
             // define input tensor descriptor for threadwise copy
             // thread input tensor, src of threadwise copy
-            constexpr auto c_m0_m1_n0_n1_thread_desc =
-                make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<MRepeat>{},
-                                                                          Number<MPerThread>{},
-                                                                          Number<NRepeat>{},
-                                                                          Number<NPerThread>{}));
+            constexpr auto c_k_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+                make_tuple(Number<MPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));
 
             // calculate origin of thread input tensor on global memory
             // blockwise GEMM c matrix starting index
+#if 0
             const auto c_thread_mtx_on_block =
                 blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
@@ -392,47 +421,54 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             const index_t n_thread_data_on_global =
                 n_block_data_on_global + c_thread_mtx_on_block.col;
+#endif
+
+            const index_t h_thread_id = get_thread_local_1d_id() / 8;
+            const index_t w_thread_id = get_thread_local_1d_id() % 8;
+
+            const index_t m_thread_data_on_global = m_block_data_on_global;
+            const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id;
+            const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id;
 
-            // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-            constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
+            // hack to control index calculation when iterating over c_k_n_h_w_global tensor
+            constexpr auto c_k_n_h_w_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
 
-            constexpr auto tmp = make_unmerge_transform(make_tuple(
-                Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
+            // constexpr auto tmp = make_unmerge_transform(make_tuple(
+            //     Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
 
             ThreadwiseDynamicTensorSliceTransfer_v1r3<
                 AccFloat,
                 Float,
-                decltype(c_m0_m1_n0_n1_thread_desc),
-                decltype(c_m0_m1_n0_n1_global_desc),
-                Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
-                CThreadTransferSrcDstAccessOrder,
-                CThreadTransferSrcDstVectorDim,
-                CThreadTransferDstScalarPerVector,
+                decltype(c_k_n_h_w_thread_desc),
+                decltype(c_k_n_h_w_global_desc),
+                Sequence<MPerThread, 1, 1, 1>,
+                Sequence<3, 2, 0, 1>, // CThreadTransferSrcDstAccessOrder
+                3,                    // CThreadTransferSrcDstVectorDim
+                1,                    // CThreadTransferDstScalarPerVector
                 AddressSpace::Vgpr,
                 AddressSpace::Global,
                 CGlobalMemoryDataOperation,
                 1,
-                true>(c_m0_m1_n0_n1_global_desc,
-                      make_multi_index(m_thread_data_on_global / M1,
-                                       m_thread_data_on_global % M1,
-                                       n_thread_data_on_global / N1,
-                                       n_thread_data_on_global % N1))
-                .Run(c_m0_m1_n0_n1_thread_desc,
+                true>(
+                    c_k_n_h_w_global_desc,
+                    make_multi_index(
+                        m_thread_data_on_global, 0, h_thread_data_on_global, w_thread_data_on_global))
+                .Run(c_k_n_h_w_thread_desc,
                      make_tuple(I0, I0, I0, I0),
                      p_c_thread,
-                     c_m0_m1_n0_n1_global_desc,
+                     c_k_n_h_w_global_desc,
                      p_c_global,
-                     c_m0_m1_n0_n1_global_tensor_iterator_hacks);
         }
+                     c_k_n_h_w_global_tensor_iterator_hacks);
+        }
+#endif
     }
 
     // pass tensor descriptor by reference
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
+                        const BGlobalDesc& b_cyx_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                        const CGlobalDesc& c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
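With the new mapping, each thread writes its MPerThread x 1 x 1 x 1 result column at an origin built from the block offsets plus its tid/8, tid%8 position. A sketch of that origin computation, with illustrative block ids:

#include <cstdio>

int main()
{
    const int n_block_work_id = 1, m_block_data_on_global = 0; // illustrative
    const int h_block = n_block_work_id * 8;
    const int w_block = n_block_work_id * 8;

    for(int tid = 0; tid < 64; tid += 21)
    {
        const int ho = h_block + tid / 8;
        const int wo = w_block + tid % 8;
        std::printf("tid=%2d -> c[k=%d..+MPerThread, n=0, ho=%d, wo=%d]\n",
                    tid, m_block_data_on_global, ho, wo);
    }
    return 0;
}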
@@ -443,9 +479,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         Run(a_k_m_global_desc,
             p_a_global,
-            b_k_n_global_desc,
+            b_cyx_n_h_w_global_desc,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
+            c_k_n_h_w_global_desc,
             p_c_global,
             p_shared_block,
             integral_constant<bool, HasMainKBlockLoop>{},
@@ -456,22 +492,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc* p_b_k_n_global_desc,
+                        const BGlobalDesc* p_b_cyx_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
-                        const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
+                        const CGlobalDesc* p_c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
     {
         const auto a_k_m_global_desc = *p_a_k_m_global_desc;
-        const auto b_k_n_global_desc = *p_b_k_n_global_desc;
-        const auto c_m0_m1_n0_n1_global_desc = *p_c_m0_m1_n0_n1_global_desc;
+        const auto b_cyx_n_h_w_global_desc = *p_b_cyx_n_h_w_global_desc;
+        const auto c_k_n_h_w_global_desc = *p_c_k_n_h_w_global_desc;
 
         Run(a_k_m_global_desc,
             p_a_global,
-            b_k_n_global_desc,
+            b_cyx_n_h_w_global_desc,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
+            c_k_n_h_w_global_desc,
             p_c_global,
             integral_constant<bool, HasMainKBlockLoop>{},
             integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -481,23 +517,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ void Run(const void* p_a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const void* p_b_k_n_global_desc,
+                        const void* p_b_cyx_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
-                        const void* p_c_m0_m1_n0_n1_global_desc,
+                        const void* p_c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
     {
         const auto a_k_m_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_k_m_global_desc);
-        const auto b_k_n_global_desc = *reinterpret_cast<const BGlobalDesc*>(p_b_k_n_global_desc);
-        const auto c_m0_m1_n0_n1_global_desc =
-            *reinterpret_cast<const CGlobalDesc*>(p_c_m0_m1_n0_n1_global_desc);
+        const auto b_cyx_n_h_w_global_desc = *reinterpret_cast<const BGlobalDesc*>(p_b_cyx_n_h_w_global_desc);
+        const auto c_k_n_h_w_global_desc =
+            *reinterpret_cast<const CGlobalDesc*>(p_c_k_n_h_w_global_desc);
 
         Run(a_k_m_global_desc,
             p_a_global,
-            b_k_n_global_desc,
+            b_cyx_n_h_w_global_desc,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
+            c_k_n_h_w_global_desc,
             p_c_global,
             integral_constant<bool, HasMainKBlockLoop>{},
             integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -136,7 +136,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         // loop over tensor and copy
         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
-            // judge move forward or move backward
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep;
@@ -463,7 +462,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
         // loop over tensor and copy
         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
-            // judge move forward or move backward
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep;
@@ -500,7 +498,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
             }();
 
             // copy data
-            static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for ds_read");
+            static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
 
             vector_type<SrcData, SrcScalarPerVector> src_vector;
@@ -798,7 +796,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         // loop over tensor and copy
         static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
-            // judge move forward or move backward
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep;
@@ -978,7 +975,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         // loop over tensor and copy
         static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
-            // judge move forward or move backward
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep;
@@ -75,9 +75,9 @@ struct ThreadwiseGemm_km_kn_mn_v1
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
 
-        constexpr auto M = CDesc{}[I0];
-        constexpr auto N = CDesc{}[I1];
-        constexpr auto K = ADesc{}[I0];
+        constexpr auto M = CDesc{}.GetLength(I0);
+        constexpr auto N = CDesc{}.GetLength(I1);
+        constexpr auto K = ADesc{}.GetLength(I0);
 
         static_for<0, K, 1>{}([&](auto k) {
             static_for<0, M, 1>{}([&](auto m) {
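For reference, the loop nest above implements the contraction stated earlier in the GEMM definition comment, c_mtx += transpose(a_mtx) * b_mtx, with A stored K-major. A scalar reference of the same computation:

#include <cstdio>

int main()
{
    const int M = 2, N = 2, K = 3;
    float a[3][2] = {{1, 2}, {3, 4}, {5, 6}}; // a[k][m], i.e. A^T layout
    float b[3][2] = {{1, 0}, {0, 1}, {1, 1}}; // b[k][n]
    float c[2][2] = {};

    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m][n] += a[k][m] * b[k][n];

    std::printf("c = [[%g, %g], [%g, %g]]\n", c[0][0], c[0][1], c[1][0], c[1][1]);
    return 0;
}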
@@ -37,7 +37,7 @@
 #endif
 
 #ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
+#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0
 #endif
 
 #ifndef CK_USE_AMD_V_FMAC_F32
@@ -67,7 +67,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
 #endif
 
-#if 1
     // cdata = 16, BlockSize = 64, 16x64x4
     constexpr index_t BlockSize = 64;
@@ -75,17 +74,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 4;
 
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
+    constexpr index_t GemmMPerThread = 16;
+    constexpr index_t GemmNPerThread = 1;
     constexpr index_t GemmKPerThread = 1;
 
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
+    constexpr index_t GemmMLevel0Cluster = 1;
+    constexpr index_t GemmNLevel0Cluster = 1;
+    constexpr index_t GemmMLevel1Cluster = 1;
+    constexpr index_t GemmNLevel1Cluster = 64;
 
     using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
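A consistency check on the new tuning values: 64 threads arranged 1x64 over (M, N), each owning a 16x1 macro-tile, exactly tile the 16x64 block tile. Sketch:

#include <cassert>

int main()
{
    const int BlockSize = 64;
    const int GemmMPerBlock = 16, GemmNPerBlock = 64;
    const int GemmMPerThread = 16, GemmNPerThread = 1;
    const int MClusters = 1 * 1;  // GemmMLevel0Cluster * GemmMLevel1Cluster
    const int NClusters = 1 * 64; // GemmNLevel0Cluster * GemmNLevel1Cluster

    assert(MClusters * NClusters == BlockSize);
    assert(GemmMPerThread * MClusters == GemmMPerBlock);
    assert(GemmNPerThread * NClusters == GemmNPerBlock);
    return 0;
}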
@@ -99,265 +95,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
 
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
-#elif 0
-    // cdata = 16, BlockSize = 64, 16x64x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
-#elif 0
-    // cdata = 32, BlockSize = 64, 16x128x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize = 128, 32x256x8
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 256;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 16;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x2
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 2;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
// b thread copy 4x1
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
// b thread copy 2x2
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#endif
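The retained 16x64x4 configuration is this commit's new thread mapping: the 64 threads form a 1x64 cluster grid, and each thread owns a 16x1 slice of the 16x64 block tile. For that mapping to tile the block exactly, the parameters have to satisfy a few divisibility invariants; a compile-time sanity check, assuming the usual v5r1 relations between block, cluster, and per-thread sizes (the relations hold for every config in this file, so the assumption seems safe):

// Assumed invariants of the v5r1 blockwise GEMM layout.
static_assert(BlockSize == GemmMLevel0Cluster * GemmNLevel0Cluster *
                               GemmMLevel1Cluster * GemmNLevel1Cluster,
              "one thread per cluster cell: 1 * 1 * 1 * 64 == 64");
static_assert(GemmMPerBlock % (GemmMPerThread * GemmMLevel0Cluster *
                               GemmMLevel1Cluster) == 0,
              "M tile must be a multiple of per-repeat M coverage: 16 % 16 == 0");
static_assert(GemmNPerBlock % (GemmNPerThread * GemmNLevel0Cluster *
                               GemmNLevel1Cluster) == 0,
              "N tile must be a multiple of per-repeat N coverage: 64 % 64 == 0");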
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
......
...@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
#elif 1
constexpr index_t N = 1;
constexpr index_t C = 4;
-constexpr index_t HI = 1080;
+constexpr index_t HI = 8;
-constexpr index_t WI = 1920;
+constexpr index_t WI = 8;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
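Shrinking the test problem from 1080x1920 to 8x8 keeps the implied GEMM sizes exactly divisible by the 16x64x4 block, assuming unit strides/dilations and padding that keeps Ho == HI and Wo == WI (which is what the _pad driver with a 3x3 filter suggests):

// Hand check using the driver's GemmM/GemmN/GemmK definitions:
constexpr index_t Ho = 8, Wo = 8;      // assumed "same"-padded output size
constexpr index_t GemmM = K;           // 16,          16 % GemmMPerBlock(16) == 0
constexpr index_t GemmN = N * Ho * Wo; // 1*8*8 = 64,  64 % GemmNPerBlock(64) == 0
constexpr index_t GemmK = C * Y * X;   // 4*3*3 = 36,  36 % GemmKPerBlock(4)  == 0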
...@@ -657,7 +657,7 @@ int main(int argc, char* argv[]) ...@@ -657,7 +657,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
-#if 0
+#if 1
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
...@@ -776,9 +776,9 @@ int main(int argc, char* argv[])
}
check_error(out_nkhw_host, out_nkhw_device);
-#if 0
+#if 1
-LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
+// LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
-LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
+// LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
#endif
......
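With both verification blocks flipped to #if 1, the run uses all-ones inputs and prints only the host/device outputs. That makes results easy to eyeball: every output whose 3x3 window lies fully inside the 8x8 input should equal C * Y * X = 36, with smaller values only where zero padding contributes. A hypothetical spot check, assuming unit padding and the host Tensor's operator() indexing:

#include <cassert>

// Hypothetical helper for the all-ones test: interior outputs must be 36.
void spot_check_interior(const Tensor<float>& out_nkhw)
{
    for(std::size_t h = 1; h + 1 < 8; ++h)      // skip the padded border
        for(std::size_t w = 1; w + 1 < 8; ++w)
            assert(out_nkhw(0, 0, h, w) == 36); // sum of C*Y*X ones
}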