Commit 457ee9a1 authored by Chao Liu's avatar Chao Liu
Browse files

update v4r4 nhwc

parent fe325666
...@@ -150,6 +150,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -150,6 +150,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// c_block_cluster_desc
const auto gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over a_k_m_global tensor // hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_k_m_global_iterator_hacks = constexpr auto a_k_m_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
...@@ -189,6 +193,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -189,6 +193,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
decltype(gemm_block_cluster_desc),
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
...@@ -256,6 +261,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -256,6 +261,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*, FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -270,6 +276,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -270,6 +276,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
p_in_global, p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global, p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
} }
...@@ -284,6 +291,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -284,6 +291,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*, FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -298,6 +306,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -298,6 +306,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
p_in_global, p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global, p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
...@@ -312,6 +321,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -312,6 +321,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*, FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -326,6 +336,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -326,6 +336,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
p_in_global, p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global, p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
} }
...@@ -340,6 +351,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -340,6 +351,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*, FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -354,6 +366,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -354,6 +366,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
p_in_global, p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global, p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
...@@ -522,6 +535,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -522,6 +535,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
} }
}; };
#if 0
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
...@@ -1013,6 +1027,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -1013,6 +1027,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif #endif
} }
}; };
#endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr auto C0 = C / Number<InWeiVectorSize>{}; constexpr auto C0 = C / Number<InWeiVectorSize>{};
constexpr auto C1 = Number<InWeiVectorSize>{}; constexpr auto C1 = Number<InWeiVectorSize>{};
#if 0 #if 1
// run-time variables // run-time variables
constexpr auto in_n_hi_wi_c0_desc = constexpr auto in_n_hi_wi_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0)); make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
#if 1 #if 0
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
//#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
//#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" //#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -724,7 +724,7 @@ int main(int argc, char* argv[]) ...@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
in_vector_size, in_vector_size,
acc_data_t, acc_data_t,
...@@ -740,7 +740,7 @@ int main(int argc, char* argv[]) ...@@ -740,7 +740,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size, in_vector_size,
acc_data_t, acc_data_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment