Commit 95710403 authored by Jing Zhang's avatar Jing Zhang
Browse files

add kpack with incorrect results

parent 44078dba
...@@ -64,6 +64,10 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -64,6 +64,10 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const auto InRightPadH = in_right_pads[I0]; const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1]; const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
// weight tensor // weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
...@@ -71,6 +75,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -71,6 +75,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK / GemmKPack, GemmKPack)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor // input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc, in_n_c_hi_wi_global_desc,
...@@ -97,6 +108,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -97,6 +108,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK / GemmKPack, GemmKPack)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor // output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
...@@ -104,11 +122,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -104,11 +122,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); const auto GemmK0 = wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0);
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{}; constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
...@@ -129,22 +147,26 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -129,22 +147,26 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{})); make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor // hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
constexpr auto wei_gemmk_gemmm_global_iterator_hacks = constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk_gemmn_global tensor // hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
constexpr auto in_gemmk_gemmn_global_iterator_hacks = constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format // tensor hack for NKHW format
...@@ -158,15 +180,15 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -158,15 +180,15 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk_gemmm_global_desc, return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
in_gemmk_gemmn_global_desc, in_gemmk0_gemmn_gemmk1_global_desc,
out_m0_m1_m2_n_global_desc, out_m0_m1_m2_n_global_desc,
out_gemm_block_cluster_desc, out_gemm_block_cluster_desc,
wei_gemmk_gemmm_global_iterator_hacks, wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
in_gemmk_gemmn_global_iterator_hacks, in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
out_m0_m1_m2_n_global_iterator_hacks, out_m0_m1_m2_n_global_iterator_hacks,
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
in_gemmk_gemmn_global_move_slice_window_iterator_hacks); in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
} }
} // namespace ck } // namespace ck
......
...@@ -73,64 +73,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -73,64 +73,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2( const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(OutDesc::GetLengths())); sequence_to_tuple_of_number(OutDesc::GetLengths()));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif #endif
#if 0
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#else
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4; constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, GemmKPack>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, GemmKPack>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#endif
const auto descs = const auto descs =
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<TInWei, transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<TInWei,
...@@ -167,21 +141,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -167,21 +141,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
GemmKPack, GemmKPack,
MRepeat, MRepeat,
NRepeat, NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0>, Sequence<1, 0, 2>,
Sequence<1, 0>, Sequence<1, 0, 2>,
0, 2,
GemmABlockTransferSrcScalarPerVector_GemmK, GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM, GemmABlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 1>, Sequence<0, 2, 1>,
Sequence<0, 1>, Sequence<0, 2, 1>,
1, 1,
GemmBBlockTransferSrcScalarPerVector_GemmN, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmBBlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy, which will be fused with false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation // MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>, Sequence<2, 3, 0, 1>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment