"vscode:/vscode.git/clone" did not exist on "d3801b7482a155648af64fd43cace6ef3d23fcde"
Commit d3146496 authored by Jing Zhang's avatar Jing Zhang
Browse files

restore nchwc

parent 9f92c019
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp" #include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
template <typename TInWei, template <typename TInWei,
ck::index_t InWeiVectorSize,
typename TAcc, typename TAcc,
typename TOut, typename TOut,
typename InLengths, typename InLengths,
...@@ -49,7 +48,9 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -49,7 +48,9 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
const auto Y = wei_k_c_y_x_lengths[I2]; const auto Y = wei_k_c_y_x_lengths[I2];
const auto X = wei_k_c_y_x_lengths[I3]; const auto X = wei_k_c_y_x_lengths[I3];
#if 0 constexpr auto InWeiVectorSize = 8;
#if 1
const auto C0 = C / Number<InWeiVectorSize>{}; const auto C0 = C / Number<InWeiVectorSize>{};
const auto C1 = Number<InWeiVectorSize>{}; const auto C1 = Number<InWeiVectorSize>{};
......
...@@ -102,7 +102,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nhwc_kyxc_nhwk( ...@@ -102,7 +102,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nhwc_kyxc_nhwk(
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1; constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, 8>; using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, 16, 1>; using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, 16, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
......
...@@ -11,7 +11,7 @@ template <ck::index_t BlockSize, ...@@ -11,7 +11,7 @@ template <ck::index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
ck::index_t E1, ck::index_t E1,
ck::index_t C1PerBlock, ck::index_t E2,
ck::index_t KPerBlock, ck::index_t KPerBlock,
ck::index_t HoPerBlock, ck::index_t HoPerBlock,
ck::index_t WoPerBlock, ck::index_t WoPerBlock,
...@@ -20,11 +20,11 @@ template <ck::index_t BlockSize, ...@@ -20,11 +20,11 @@ template <ck::index_t BlockSize,
ck::index_t HoPerThread, ck::index_t HoPerThread,
ck::index_t WoPerThread, ck::index_t WoPerThread,
ck::index_t EPerThread, ck::index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E0_E1_K_C1, typename ABlockTransferThreadSliceLengths_E0_E1_K_E2,
typename ABlockTransferThreadClusterLengths_E0_E1_K_C1, typename ABlockTransferThreadClusterLengths_E0_E1_K_E2,
ck::index_t ABlockTransferSrcScalarPerVector_C1, ck::index_t ABlockTransferSrcScalarPerVector_E2,
ck::index_t ABlockTransferDstScalarPerVector_C1, ck::index_t ABlockTransferDstScalarPerVector_E2,
ck::index_t BThreadTransferSrcScalarPerVector_C1, ck::index_t BThreadTransferSrcScalarPerVector_E2,
ck::index_t CThreadTransferDstScalarPerVector_K> ck::index_t CThreadTransferDstScalarPerVector_K>
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
{ {
...@@ -95,22 +95,24 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -95,22 +95,24 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
const auto E = C0 * Y * X; const auto E = C0 * Y * X;
static_assert(E2 == C1, "");
const auto E0 = E / E1; const auto E0 = E / E1;
// weight tensor // weight tensor
const auto a_e0_k_c2_grid_desc = transform_tensor_descriptor( const auto a_e0_k_e2_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, C1)), make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(K),
make_pass_through_transform(C0 * Y * X), make_pass_through_transform(C0 * Y * X),
make_pass_through_transform(C1)), make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{})); make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));
const auto a_e0_e1_k_c2_grid_desc = const auto a_e0_e1_k_e2_grid_desc =
transform_tensor_descriptor(a_e0_k_c2_grid_desc, transform_tensor_descriptor(a_e0_k_e2_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(E0, E1)), make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
make_pass_through_transform(K), make_pass_through_transform(K),
make_pass_through_transform(C1)), make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
...@@ -121,7 +123,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -121,7 +123,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_pass_through_transform(C0), make_pass_through_transform(C0),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C1)), make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
...@@ -132,29 +134,29 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -132,29 +134,29 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_pass_through_transform(C0), make_pass_through_transform(C0),
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)), make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C1)), make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{})); Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
const auto b_e0_n_ho_wo_c2_grid_desc = transform_tensor_descriptor( const auto b_e0_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
in_n_c0_y_ho_x_wo_c1_global_desc, in_n_c0_y_ho_x_wo_c1_global_desc,
make_tuple(make_merge_transform(make_tuple(C0, Y, X)), make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
make_pass_through_transform(N), make_pass_through_transform(N),
make_pass_through_transform(Hop), make_pass_through_transform(Hop),
make_pass_through_transform(Wop), make_pass_through_transform(Wop),
make_pass_through_transform(C1)), make_pass_through_transform(E2)),
make_tuple( make_tuple(
Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}), Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto b_e0_e1_n_ho_wo_c2_grid_desc = transform_tensor_descriptor( const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
b_e0_n_ho_wo_c2_grid_desc, b_e0_n_ho_wo_e2_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(E0, E1)), make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
make_pass_through_transform(N), make_pass_through_transform(N),
make_pass_through_transform(Hop), make_pass_through_transform(Hop),
make_pass_through_transform(Wop), make_pass_through_transform(Wop),
make_pass_through_transform(C1)), make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple( make_tuple(
Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{})); Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
...@@ -172,8 +174,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -172,8 +174,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
(E1 % E1PerBlock) == 0) && (E1 % E1PerBlock) == 0))
(C1 % C1PerBlock))
{ {
throw std::runtime_error("wrong! GEMM size no divisible"); throw std::runtime_error("wrong! GEMM size no divisible");
} }
...@@ -228,11 +229,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -228,11 +229,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
decltype(a_e0_e1_k_c2_grid_desc), decltype(a_e0_e1_k_e2_grid_desc),
decltype(b_e0_e1_n_ho_wo_c2_grid_desc), decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
decltype(c_k_n_hop_wop_grid_desc), decltype(c_k_n_hop_wop_grid_desc),
E1, E1,
C1PerBlock, E2,
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
...@@ -241,17 +242,17 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -241,17 +242,17 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
HoPerThread, HoPerThread,
WoPerThread, WoPerThread,
EPerThread, EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K_C1, ABlockTransferThreadSliceLengths_E0_E1_K_E2,
ABlockTransferThreadClusterLengths_E0_E1_K_C1, ABlockTransferThreadClusterLengths_E0_E1_K_E2,
Sequence<2, 0, 1, 3>, Sequence<2, 0, 1, 3>,
Sequence<2, 0, 1, 3>, Sequence<2, 0, 1, 3>,
3, 3,
ABlockTransferSrcScalarPerVector_C1, ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_C1, ABlockTransferDstScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 3, 4, 1, 5>, Sequence<0, 2, 3, 4, 1, 5>,
5, 5,
BThreadTransferSrcScalarPerVector_C1, BThreadTransferSrcScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy, which will be fused with false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation // MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 1, 0>, Sequence<2, 3, 1, 0>,
...@@ -263,8 +264,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -263,8 +264,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack), decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
decltype(b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack)>; decltype(b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack)>;
using AGridDesc_E0_E1_K_C1 = decltype(a_e0_e1_k_c2_grid_desc); using AGridDesc_E0_E1_K_E2 = decltype(a_e0_e1_k_e2_grid_desc);
using BGridDesc_E0_E1_N_Ho_Wo_C1 = decltype(b_e0_e1_n_ho_wo_c2_grid_desc); using BGridDesc_E0_E1_N_Ho_Wo_E2 = decltype(b_e0_e1_n_ho_wo_e2_grid_desc);
using CGridDesc_K_N_Ho_Wo = decltype(c_k_n_hop_wop_grid_desc); using CGridDesc_K_N_Ho_Wo = decltype(c_k_n_hop_wop_grid_desc);
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
...@@ -294,8 +295,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -294,8 +295,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_C1>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_C1>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
...@@ -309,8 +310,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -309,8 +310,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e0_e1_k_c2_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_c2_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
...@@ -320,8 +321,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -320,8 +321,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_C1>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_C1>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
...@@ -335,8 +336,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -335,8 +336,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e0_e1_k_c2_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_c2_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
...@@ -346,8 +347,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -346,8 +347,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_C1>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_C1>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
...@@ -361,8 +362,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -361,8 +362,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e0_e1_k_c2_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_c2_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
...@@ -372,8 +373,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -372,8 +373,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_C1>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_C1>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
...@@ -387,22 +388,22 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -387,22 +388,22 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e0_e1_k_c2_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_c2_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
return ave_time; return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_e0_e1_k_c2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K_C1)); DeviceMem a_e0_e1_k_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K_E2));
DeviceMem b_e0_e1_n_ho_wo_c2_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo_C1)); DeviceMem b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo_E2));
DeviceMem c_k_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K_N_Ho_Wo)); DeviceMem c_k_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K_N_Ho_Wo));
DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf( DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo)); sizeof(CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo));
a_e0_e1_k_c2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k_c2_grid_desc); a_e0_e1_k_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k_e2_grid_desc);
b_e0_e1_n_ho_wo_c2_grid_desc_dev_buf.ToDevice(&b_e0_e1_n_ho_wo_c2_grid_desc); b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.ToDevice(&b_e0_e1_n_ho_wo_e2_grid_desc);
c_k_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k_n_hop_wop_grid_desc); c_k_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k_n_hop_wop_grid_desc);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice( c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_k_n_ho_wo_block_cluster_adaptor); &c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
...@@ -415,8 +416,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -415,8 +416,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_C1>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_C1>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
...@@ -432,9 +433,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -432,9 +433,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k_c2_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_c2_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
...@@ -446,8 +447,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -446,8 +447,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_C1>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_C1>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
...@@ -463,9 +464,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -463,9 +464,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k_c2_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_c2_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
...@@ -477,8 +478,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -477,8 +478,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_C1>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_C1>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
...@@ -494,9 +495,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -494,9 +495,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k_c2_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_c2_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
...@@ -508,8 +509,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -508,8 +509,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_C1>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_C1>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
...@@ -525,9 +526,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -525,9 +526,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k_c2_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_c2_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
......
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
#define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NHWC 1 #define USE_CONV_FWD_V5R1_NHWC 0
#define USE_CONV_FWD_V5R1_NCHWC 1
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
...@@ -33,9 +34,10 @@ enum ConvForwardAlgo ...@@ -33,9 +34,10 @@ enum ConvForwardAlgo
V4R4NCHW, // 0 V4R4NCHW, // 0
V4R4R2NHWC, // 1 V4R4R2NHWC, // 1
V6R1NCHW, // 2 V6R1NCHW, // 2
V5R1NHWC, // 3 V5R1NCHWc, // 3
V4R4R2XDLNCHW, // 4 V5R1NHWC, // 4
V4R4R4XDLNHWC // 5 V4R4R2XDLNCHW, // 5
V4R4R4XDLNHWC // 6
}; };
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -342,6 +344,32 @@ int main(int argc, char* argv[]) ...@@ -342,6 +344,32 @@ int main(int argc, char* argv[])
} }
#endif #endif
// Dispatch for the V5R1NCHWc forward-convolution algorithm.
// This whole section is compiled only when USE_CONV_FWD_V5R1_NCHWC is non-zero.
#if USE_CONV_FWD_V5R1_NCHWC
if(algo == ConvForwardAlgo::V5R1NCHWc)
{
// This path only accepts the NCHW tensor layout; any other layout is rejected
// with an exception before the device function is called.
if(layout != ConvTensorLayout::NCHW)
{
throw std::runtime_error("wrong! layout");
}
// tmp is indexed with I0..I6 below, i.e. seven elements are read from it.
// NOTE(review): presumably these are the NCHW length/stride/dilation/pad
// descriptors built for the device call -- confirm against f_make_for_device_nchw.
const auto tmp = f_make_for_device_nchw();
// Only the three data types are passed as explicit template arguments; the
// remaining arguments are the seven tmp elements followed by in, wei,
// out_device and nrepeat.
device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V5R1_NHWC #if USE_CONV_FWD_V5R1_NHWC
if(algo == ConvForwardAlgo::V5R1NHWC) if(algo == ConvForwardAlgo::V5R1NHWC)
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment