Commit 318db82b authored by Chao Liu's avatar Chao Liu
Browse files

overhauling fwd-v4r4

parent d99e020d
...@@ -10,11 +10,7 @@ namespace ck { ...@@ -10,11 +10,7 @@ namespace ck {
// GemmM = K // GemmM = K
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = C * Y * X // GemmK = C * Y * X
template <index_t GemmMPerBlock, template <typename... Wei,
index_t GemmNPerBlock,
index_t GemmM1,
index_t GemmN1,
typename... Wei,
typename... In, typename... In,
typename... Out, typename... Out,
typename ConvStrides, typename ConvStrides,
...@@ -101,30 +97,8 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( ...@@ -101,30 +97,8 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); return make_tuple(
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto GemmM0 = GemmM / Number<GemmM1>{};
const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
return make_tuple(wei_gemmk_gemmm_global_desc,
in_gemmk_gemmn_global_desc,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
out_gemm_block_cluster_desc);
} }
} // namespace ck } // namespace ck
......
...@@ -469,42 +469,61 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -469,42 +469,61 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#endif #endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
const auto GemmM = out_gemmm_gemmn_grid_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_grid_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_grid_desc.GetLength(I0);
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster; constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster; constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
const auto descs = assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad<GemmMPerBlock,
GemmNPerBlock, const auto GemmM0 = GemmM / Number<GemmM1>{};
GemmM1, const auto GemmN0 = GemmN / Number<GemmN1>{};
GemmN1>(
wei_k_c_y_x_desc, const auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_desc, out_gemmm_gemmn_grid_desc,
out_n_k_ho_wo_desc, make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
conv_strides, make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
conv_dilations, make_tuple(Sequence<0>{}, Sequence<1>{}),
in_left_pads, make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
in_right_pads);
// out_gemm_block_cluster_desc
// hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
constexpr auto wei_gemmk_gemmm_global_iterator_hacks = make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor
constexpr auto wei_gemmk_gemmm_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; constexpr auto wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk_gemmn_global tensor // hack to control index calculation when iterating over in_gemmk_gemmn_grid tensor
constexpr auto in_gemmk_gemmn_global_iterator_hacks = constexpr auto in_gemmk_gemmn_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_grid
// tensor hack for NKHW format constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks =
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
...@@ -522,10 +541,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -522,10 +541,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(descs[I0]), decltype(wei_gemmk_gemmm_grid_desc),
decltype(descs[I1]), decltype(in_gemmk_gemmn_grid_desc),
decltype(descs[I2]), decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc),
decltype(descs[I3]), decltype(out_gemm_block_cluster_desc),
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
...@@ -556,25 +575,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -556,25 +575,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence<2, 3, 0, 1>, Sequence<2, 3, 0, 1>,
3, 3,
GemmCThreadTransferDstScalarPerVector_GemmN1, GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(wei_gemmk_gemmm_global_iterator_hacks), decltype(wei_gemmk_gemmm_grid_iterator_hacks),
decltype(in_gemmk_gemmn_global_iterator_hacks), decltype(in_gemmk_gemmn_grid_iterator_hacks),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks), decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks),
decltype(wei_gemmk_gemmm_global_move_slice_window_iterator_hacks), decltype(wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk_gemmn_global_move_slice_window_iterator_hacks)>( decltype(in_gemmk_gemmn_grid_move_slice_window_iterator_hacks)>(
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c_y_x_device_buf.GetDeviceBuffer()), wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()), in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I0], wei_gemmk_gemmm_grid_desc,
descs[I1], in_gemmk_gemmn_grid_desc,
descs[I2], out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc,
descs[I3], out_gemm_block_cluster_desc,
wei_gemmk_gemmm_global_iterator_hacks, wei_gemmk_gemmm_grid_iterator_hacks,
in_gemmk_gemmn_global_iterator_hacks, in_gemmk_gemmn_grid_iterator_hacks,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks,
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks,
in_gemmk_gemmn_global_move_slice_window_iterator_hacks, in_gemmk_gemmn_grid_move_slice_window_iterator_hacks,
nrepeat); nrepeat);
float perf = (float)calculate_convolution_flops( float perf = (float)calculate_convolution_flops(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment