Unverified Commit 88833bd9 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Merge pull request #32 from ROCmSoftwarePlatform/develop

Merge develop into master
parents 31b40352 f3acd251
...@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template <typename... Xs> template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs) __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{ {
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...); return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
} }
} // namespace ck } // namespace ck
......
...@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>> ...@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>>
return container_reduce( return container_reduce(
Tuple<Ts...>{}, Tuple<Ts...>{},
[](auto x, bool r) { [](auto x, bool r) {
return is_known_at_compile_time< return is_known_at_compile_time<remove_cvref_t<decltype(x)>>::value & r;
remove_cv_t<remove_reference_t<decltype(x)>>>::value &
r;
}, },
true); true);
} }
......
...@@ -22,6 +22,9 @@ using remove_reference_t = typename std::remove_reference<T>::type; ...@@ -22,6 +22,9 @@ using remove_reference_t = typename std::remove_reference<T>::type;
template <typename T> template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type; using remove_cv_t = typename std::remove_cv<T>::type;
template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template <typename T> template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value; inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
......
...@@ -374,13 +374,8 @@ extern "C" __global__ void ...@@ -374,13 +374,8 @@ extern "C" __global__ void
CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
CGridBlockCluster_BlockId_To_GM10_GN10{})); CGridBlockCluster_BlockId_To_GM10_GN10{}));
const auto desc_tuple = *reinterpret_cast<const DescTuple*>( const auto desc_tuple =
#pragma clang diagnostic push *reinterpret_cast<const DescTuple*>(cast_pointer_to_generic_address_space(p_desc_tuple));
#pragma clang diagnostic ignored "-Wold-style-cast"
// TODO: how to cast?
(const void*)p_desc_tuple
#pragma clang diagnostic pop
);
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1]; const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
......
...@@ -13,9 +13,15 @@ include_directories(BEFORE ...@@ -13,9 +13,15 @@ include_directories(BEFORE
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp)
set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp)
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE})
add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE})
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor)
target_link_libraries(gemm_driver_offline PRIVATE host_tensor)
...@@ -56,8 +56,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -56,8 +56,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4; constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
...@@ -84,8 +84,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -84,8 +84,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
...@@ -112,8 +112,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -112,8 +112,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 4;
...@@ -140,8 +140,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -140,8 +140,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
...@@ -168,8 +168,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -168,8 +168,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4; constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 4;
...@@ -208,40 +208,42 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -208,40 +208,42 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( // clang-format off
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
//clang-format on
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
...@@ -263,8 +265,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -263,8 +265,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
GemmMPerWave, GemmMPerXDL,
GemmNPerWave, GemmNPerXDL,
GemmK1, GemmK1,
MRepeat, MRepeat,
NRepeat, NRepeat,
...@@ -289,7 +291,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -289,7 +291,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_m1_m2_n_grid_step_hacks), decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat false // CAccessOrderMRepeatNRepeat
...@@ -301,7 +303,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -301,7 +303,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
in_gemmm_gemmn_grid_desc, in_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_step_hacks, out_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_m1_m2_n_grid_step_hacks, in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
......
...@@ -195,25 +195,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -195,25 +195,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( // clang-format off
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
//clang-format on
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
...@@ -265,7 +267,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -265,7 +267,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_m1_m2_n_grid_step_hacks), decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
true // CAccessOrderMRepeatNRepeat true // CAccessOrderMRepeatNRepeat
...@@ -277,7 +279,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -277,7 +279,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
in_gemmm_gemmn_grid_desc, in_gemmm_gemmn_grid_desc,
out_gemmk0_gemmm_gemmk1_grid_step_hacks, out_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_m1_m2_n_grid_step_hacks, in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
......
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" #include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r2.hpp" #include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -14,17 +14,17 @@ template <typename TInWei, ...@@ -14,17 +14,17 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_hi_wi_c_lengths, const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_y_x_c_lengths, const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_ho_wo_k_lengths, const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c, const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_y_x_c, Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_ho_wo_k, const Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -34,138 +34,149 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( ...@@ -34,138 +34,149 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1 #if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
// using vector load 4, so config's wo*ho must be a multiple of 4
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1 #elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 1; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
// using vector load 4, so config's wo*ho must be a multiple of 4
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif #endif
const auto descs = const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, wei_k_c_y_x_desc,
in_n_hi_wi_c_desc, in_n_c_hi_wi_desc,
out_n_ho_wo_k_desc, out_n_k_ho_wo_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
Number<GemmK1>{}); Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2]; const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0
make_tuple( Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
constexpr auto out_m0_m1_m2_n_grid_step_hacks = constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 1, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_gemm_xdlops_v2r2< float ave_time = driver_gemm_xdlops_v2r3<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
GemmMPerWave, GemmMPerWave,
GemmNPerWave, GemmNPerWave,
GemmK1,
MRepeat, MRepeat,
NRepeat, NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
...@@ -181,49 +192,37 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( ...@@ -181,49 +192,37 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
2, 2,
GemmBBlockTransferSrcScalarPerVector_GemmK1, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1, GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1>, Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
2, 7,
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks), decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>( decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), false>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc, out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc, in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc, wei_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks, out_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks, in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks, wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
{ float perf = static_cast<float>(calculate_convolution_flops(
const auto N = out_n_ho_wo_k_lengths[I0]; in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time; (std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
<< std::endl;
}
} }
// copy result back to host // copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
} }
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#endif
const auto descs =
#if 1
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
#else
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1
#endif
<TInWei, GemmMPerBlock, GemmNPerBlock, GemmMPerWave, GemmNPerWave, GemmKPack>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
for(index_t i = 0; i < 5; ++i)
{
#if 0
float ave_time = launch_kernel_gemm_xdlops_v1
#else
float ave_time = launch_kernel_gemm_xdlops_v2
#endif
<BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(descs[I0]),
decltype(descs[I1]),
decltype(descs[I2]),
decltype(descs[I3]),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPack,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1>,
Sequence<1, 0, 2>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy, which will be fused
// with MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(descs[I4]),
decltype(descs[I5]),
decltype(descs[I6]),
decltype(descs[I7]),
decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
descs[I2],
descs[I3],
descs[I4],
descs[I5],
descs[I6],
descs[I7],
descs[I8],
nrepeat);
float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
...@@ -47,7 +47,35 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -47,7 +47,35 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1 #if 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16 // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -92,36 +120,39 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -92,36 +120,39 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const auto out_gemmm_gemmn_grid_desc = descs[I2]; const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
make_tuple( Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto out_m0_m1_m2_n_grid_step_hacks = constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0>{};
...@@ -169,7 +200,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -169,7 +200,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks), decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()), false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
...@@ -180,7 +211,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -180,7 +211,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
out_gemmm_gemmn_grid_desc, out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks, in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks, out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
......
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
6,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}
...@@ -56,8 +56,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -56,8 +56,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4; constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 4;
...@@ -84,8 +84,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -84,8 +84,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4; constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
...@@ -112,8 +112,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -112,8 +112,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 4;
...@@ -140,8 +140,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -140,8 +140,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 4;
...@@ -168,8 +168,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -168,8 +168,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
...@@ -196,8 +196,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -196,8 +196,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
...@@ -249,23 +249,23 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -249,23 +249,23 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto out_m0_m1_m2_n_grid_step_hacks = constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0>{}, // 4+: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0>{}, // 5+: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0>{}, // 6+: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0>{}), // 7+: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0>{}, // 4-: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0>{}, // 5-: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
...@@ -287,8 +287,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -287,8 +287,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
GemmMPerWave, GemmMPerXDL,
GemmNPerWave, GemmNPerXDL,
GemmK1, GemmK1,
MRepeat, MRepeat,
NRepeat, NRepeat,
...@@ -313,7 +313,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -313,7 +313,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks), decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat false // CAccessOrderMRepeatNRepeat
...@@ -325,7 +325,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -325,7 +325,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
out_gemmm_gemmn_grid_desc, out_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm_gemmk1_grid_step_hacks, in_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks, out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -17,8 +17,8 @@ template <ck::index_t BlockSize, ...@@ -17,8 +17,8 @@ template <ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
ck::index_t KPerBlock, ck::index_t KPerBlock,
ck::index_t MPerWave, ck::index_t MPerXDL,
ck::index_t NPerWave, ck::index_t NPerXDL,
ck::index_t K1, ck::index_t K1,
ck::index_t MRepeat, ck::index_t MRepeat,
ck::index_t NRepeat, ck::index_t NRepeat,
...@@ -79,8 +79,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -79,8 +79,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerWave, MPerXDL,
NPerWave, NPerXDL,
K1, K1,
MRepeat, MRepeat,
NRepeat, NRepeat,
...@@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
} }
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
...@@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
FloatC, FloatC,
remove_reference_t<AK0MK1GridDesc>, remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>, remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0M1M2NGridDesc>, remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>; remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
...@@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
p_c_grid, p_c_grid,
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor); c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel( float ave_time = launch_and_time_kernel(
...@@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif #endif
return ave_time; return ave_time;
......
...@@ -41,7 +41,7 @@ int main(int argc, char* argv[]) ...@@ -41,7 +41,7 @@ int main(int argc, char* argv[])
// dynamic mode // dynamic mode
if(argc != 22) if(argc != 22)
{ {
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
exit(1); exit(1);
} }
...@@ -79,7 +79,7 @@ int main(int argc, char* argv[]) ...@@ -79,7 +79,7 @@ int main(int argc, char* argv[])
// static mode // static mode
if(argc < 7) if(argc < 7)
{ {
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
exit(1); exit(1);
} }
...@@ -90,28 +90,28 @@ int main(int argc, char* argv[]) ...@@ -90,28 +90,28 @@ int main(int argc, char* argv[])
const bool do_log = std::stoi(argv[5]); const bool do_log = std::stoi(argv[5]);
const int nrepeat = std::stoi(argv[6]); const int nrepeat = std::stoi(argv[6]);
constexpr index_t N = 128; constexpr auto N = Number<128>{};
constexpr index_t C = 192; constexpr auto C = Number<192>{};
constexpr index_t Hi = 71; constexpr auto Hi = Number<71>{};
constexpr index_t Wi = 71; constexpr auto Wi = Number<71>{};
constexpr index_t K = 256; constexpr auto K = Number<256>{};
constexpr index_t Y = 3; constexpr auto Y = Number<3>{};
constexpr index_t X = 3; constexpr auto X = Number<3>{};
const index_t conv_stride_h = 2; constexpr auto conv_stride_h = I2;
const index_t conv_stride_w = 2; constexpr auto conv_stride_w = I2;
const index_t conv_dilation_h = 1; constexpr auto conv_dilation_h = I1;
const index_t conv_dilation_w = 1; constexpr auto conv_dilation_w = I1;
const index_t in_left_pad_h = 1; constexpr auto in_left_pad_h = I1;
const index_t in_left_pad_w = 1; constexpr auto in_left_pad_w = I1;
const index_t in_right_pad_h = 1; constexpr auto in_right_pad_h = I1;
const index_t in_right_pad_w = 1; constexpr auto in_right_pad_w = I1;
const index_t YEff = (Y - 1) * conv_dilation_h + 1; constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
const index_t XEff = (X - 1) * conv_dilation_w + 1; constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
#endif #endif
#if 0 #if 0
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment