Commit d99e020d authored by Chao Liu's avatar Chao Liu
Browse files

overhauling fwd-v4r4

parent 4b21c0fd
...@@ -121,44 +121,10 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( ...@@ -121,44 +121,10 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{})); make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
constexpr auto in_gemmk_gemmn_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk_gemmm_global_desc, return make_tuple(wei_gemmk_gemmm_global_desc,
in_gemmk_gemmn_global_desc, in_gemmk_gemmn_global_desc,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
out_gemm_block_cluster_desc, out_gemm_block_cluster_desc);
wei_gemmk_gemmm_global_iterator_hacks,
in_gemmk_gemmn_global_iterator_hacks,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
} }
} // namespace ck } // namespace ck
......
...@@ -74,7 +74,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 ...@@ -74,7 +74,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBKN0N1BlockDescriptor(const BKNBlockDesc& n_k_n_block_desc) MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc)
{ {
const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
BKNBlockDesc{}, BKNBlockDesc{},
......
...@@ -485,6 +485,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -485,6 +485,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
in_left_pads, in_left_pads,
in_right_pads); in_right_pads);
// hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
constexpr auto in_gemmk_gemmn_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = launch_kernel_dynamic_gemm_v1r2< float ave_time = launch_kernel_dynamic_gemm_v1r2<
...@@ -527,25 +556,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -527,25 +556,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence<2, 3, 0, 1>, Sequence<2, 3, 0, 1>,
3, 3,
GemmCThreadTransferDstScalarPerVector_GemmN1, GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(descs[I4]), decltype(wei_gemmk_gemmm_global_iterator_hacks),
decltype(descs[I5]), decltype(in_gemmk_gemmn_global_iterator_hacks),
decltype(descs[I6]), decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks),
decltype(descs[I7]), decltype(wei_gemmk_gemmm_global_move_slice_window_iterator_hacks),
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( decltype(in_gemmk_gemmn_global_move_slice_window_iterator_hacks)>(
wei_k_c_y_x_device_buf.GetDeviceBuffer()), static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( wei_k_c_y_x_device_buf.GetDeviceBuffer()),
in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
descs[I0], static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I1], descs[I0],
descs[I2], descs[I1],
descs[I3], descs[I2],
descs[I4], descs[I3],
descs[I5], wei_gemmk_gemmm_global_iterator_hacks,
descs[I6], in_gemmk_gemmn_global_iterator_hacks,
descs[I7], out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
descs[I8], wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
nrepeat); in_gemmk_gemmn_global_move_slice_window_iterator_hacks,
nrepeat);
float perf = (float)calculate_convolution_flops( float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment