Commit 318db82b authored by Chao Liu
Browse files

overhauling fwd-v4r4

parent d99e020d
......@@ -10,11 +10,7 @@ namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmM1,
index_t GemmN1,
typename... Wei,
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
......@@ -101,30 +97,8 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto GemmM0 = GemmM / Number<GemmM1>{};
const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
return make_tuple(wei_gemmk_gemmm_global_desc,
in_gemmk_gemmn_global_desc,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
out_gemm_block_cluster_desc);
return make_tuple(
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
}
} // namespace ck
......
......@@ -469,42 +469,61 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
const auto GemmM = out_gemmm_gemmn_grid_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_grid_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_grid_desc.GetLength(I0);
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad<GemmMPerBlock,
GemmNPerBlock,
GemmM1,
GemmN1>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
// hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto GemmM0 = GemmM / Number<GemmM1>{};
const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor
constexpr auto wei_gemmk_gemmm_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
constexpr auto wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
constexpr auto in_gemmk_gemmn_global_iterator_hacks =
// hack to control index calculation when iterating over in_gemmk_gemmn_grid tensor
constexpr auto in_gemmk_gemmn_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_grid
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
......@@ -522,10 +541,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(descs[I0]),
decltype(descs[I1]),
decltype(descs[I2]),
decltype(descs[I3]),
decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc),
decltype(out_gemm_block_cluster_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
......@@ -556,25 +575,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(wei_gemmk_gemmm_global_iterator_hacks),
decltype(in_gemmk_gemmn_global_iterator_hacks),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks),
decltype(wei_gemmk_gemmm_global_move_slice_window_iterator_hacks),
decltype(in_gemmk_gemmn_global_move_slice_window_iterator_hacks)>(
decltype(wei_gemmk_gemmm_grid_iterator_hacks),
decltype(in_gemmk_gemmn_grid_iterator_hacks),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks),
decltype(wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk_gemmn_grid_move_slice_window_iterator_hacks)>(
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
descs[I2],
descs[I3],
wei_gemmk_gemmm_global_iterator_hacks,
in_gemmk_gemmn_global_iterator_hacks,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
in_gemmk_gemmn_global_move_slice_window_iterator_hacks,
wei_gemmk_gemmm_grid_desc,
in_gemmk_gemmn_grid_desc,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc,
out_gemm_block_cluster_desc,
wei_gemmk_gemmm_grid_iterator_hacks,
in_gemmk_gemmn_grid_iterator_hacks,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks,
wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks,
in_gemmk_gemmn_grid_move_slice_window_iterator_hacks,
nrepeat);
float perf = (float)calculate_convolution_flops(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment