Commit 02d23347 authored by Chao Liu's avatar Chao Liu
Browse files

overhauling fwd-v4r4

parent 318db82b
...@@ -482,29 +482,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -482,29 +482,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
const auto in_gemmk_gemmn_grid_desc = descs[I1]; const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2]; const auto out_gemmm_gemmn_grid_desc = descs[I2];
const auto GemmM = out_gemmm_gemmn_grid_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_grid_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_grid_desc.GetLength(I0);
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto GemmM0 = GemmM / Number<GemmM1>{};
const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor // hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor
constexpr auto wei_gemmk_gemmm_grid_iterator_hacks = constexpr auto wei_gemmk_gemmm_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
...@@ -543,8 +520,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -543,8 +520,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_grid_desc), decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc), decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
decltype(out_gemm_block_cluster_desc),
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
...@@ -587,8 +563,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -587,8 +563,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gemmk_gemmm_grid_desc, wei_gemmk_gemmm_grid_desc,
in_gemmk_gemmn_grid_desc, in_gemmk_gemmn_grid_desc,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc, out_gemmm_gemmn_grid_desc,
out_gemm_block_cluster_desc,
wei_gemmk_gemmm_grid_iterator_hacks, wei_gemmk_gemmm_grid_iterator_hacks,
in_gemmk_gemmn_grid_iterator_hacks, in_gemmk_gemmn_grid_iterator_hacks,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks, out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment