Unverified commit 3835318c, authored by zjing14, committed by GitHub

xdlops_v4r4_fwd fp32/fp16 (#34)



* create files for xdlops

* working on blockwise_gemm_xdlops

* add KReduction

* add m/n repeats

* add 2x2 pipeline

* added 128x128 wavegemm

* use StaticBuffer of vector_type

* break vector type to blk_size

* add kpack into xdlops_gemm and blockwise_gemm

* a-broadcast only

* add fp32 mfma instructions

* adding fp16 mfma

* pack half4_t

* rename kperwave to kpack

* add 32x32x8fp16

* add fp16 mfma (a sketch of this step follows the commit message)

* clean code

* clean code

* V4r4 xdlops kpack (#35)

* add kpack with incorrect results

* bug fix for make_dynamic_naive_tensor_descriptor_aligned_v2

* add 1x1 kernel

* add gridwise_gemm_v2 - single_buffer

* enabled dwordx4 for fp16
Co-authored-by: Chao Liu <chao.liu2@amd.com>

* refactor fwd-v4r4-xdlops

* add v4r4-nhwc-xdlops

* improve perf of the nhwc and nchw kernels by tuning parameters, and change scheduling in the gridwise-gemm loop

* tweak scheduling in gridwise gemm

* add v4r3 with a single output copy

* init commit: output with slice win

* adding sliceWin

* add multiple repeats pattern

* start adding bwd-v4r1-xdlops

* use tuple as SrcBuffer

* adding bwd-data v4r1 nhwc xdlops

* fix bug in make_dynamic_naive_tensor_descriptor_aligned_v2()

* fix bug in host bwd-data conv

* initial implementation of bwd-data v4r1 nhwc xdlops

* add launch bound flags

* enable launch bound

* add m/nrepeat=4

* tweak bwd-data v4r1 nhwc xdlops

* added bwd-data v4r1 nhwc xdlops with output A and weight B

* add fwd-v4r4 nhwc xdlops, A input, B weight, C output
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 1685048a
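The `pack half4_t`, `add 32x32x8fp16`, and `add kpack into xdlops_gemm` items above all revolve around one idea: a KPack-wide slice of the K dimension is packed into a single vector operand and fed to one MFMA instruction. The snippet below is a minimal, stand-alone sketch of that wave-level step for fp16; it is not code from this commit, it assumes KPack = 4 and the gfx908 compiler builtin `__builtin_amdgcn_mfma_f32_32x32x8f16`, and the helper name `wave_gemm_step_fp16` is made up for illustration. In the PR itself this role is played by `XdlopsGemm` inside the blockwise/gridwise GEMMs.

```cpp
// Minimal sketch (not from this commit): one wave-level 32x32x8 fp16 xdlops step
// with KPack = 4, assuming the gfx908 MFMA builtin shown below is available.
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));  // KPack packed fp16 operand
typedef float float16_t __attribute__((ext_vector_type(16))); // per-lane slice of the 32x32 fp32 C tile

__device__ float16_t wave_gemm_step_fp16(const _Float16* p_a, // KPack contiguous K-elements of A for this lane
                                         const _Float16* p_b, // KPack contiguous K-elements of B for this lane
                                         float16_t c_acc)
{
    half4_t a_pack;
    half4_t b_pack;

    // pack KPack = 4 halves into one vector operand (the "pack half4_t" step)
    for(int k = 0; k < 4; ++k)
    {
        a_pack[k] = p_a[k];
        b_pack[k] = p_b[k];
    }

    // one MFMA accumulates a 32x32 fp32 tile over K = 8 for the whole wavefront
    return __builtin_amdgcn_mfma_f32_32x32x8f16(a_pack, b_pack, c_acc, 0, 0, 0);
}
```

Each of the 64 lanes then holds 16 fp32 accumulators of the 32x32 tile; roughly speaking, the blockwise GEMM in this PR repeats such steps over KPerBlock and over the MRepeat x NRepeat wave tiles.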
@@ -81,24 +81,27 @@ message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}")
 if(DEVICE_BACKEND STREQUAL "AMD")
     set(CONV_SOURCE driver/conv_driver.cpp)
-    set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cpp)
     set(CONV_V2_SOURCE driver/conv_driver_v2.cpp)
+    set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp)
     set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp)
+    set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cpp)
 elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
     set(CONV_SOURCE driver/conv_driver.cu)
     set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cu)
 endif()

-##add_executable(conv_driver ${CONV_SOURCE})
-add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})
+add_executable(conv_driver ${CONV_SOURCE})
 add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
+add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE})
 add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE})
+##add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})

 target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/)

-##target_link_libraries(conv_driver PRIVATE modConv)
-target_link_libraries(conv_bwd_data_driver PRIVATE modConv)
+target_link_libraries(conv_driver PRIVATE modConv)
 target_link_libraries(conv_driver_v2 PRIVATE modConv)
+target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv)
 target_link_libraries(conv_driver_v2_olc PRIVATE modConv)
+##target_link_libraries(conv_bwd_data_driver PRIVATE modConv)
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "driver_dynamic_gemm_xdlops_v1.hpp"
#include "driver_dynamic_gemm_xdlops_v2.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
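// Worked example (illustrative numbers, not from this commit): for
// N = 256, K = 256, C = 128, Ho = Wo = 14, Y = X = 3 and GemmKPack = 4:
//   GemmM  = K           = 256
//   GemmN  = N * Ho * Wo = 256 * 14 * 14 = 50176
//   GemmK  = C * Y * X   = 128 * 3 * 3   = 1152
//   GemmK0 = GemmK / GemmKPack = 1152 / 4 = 288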
template <typename FloatAB,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPack,
typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmKPack;
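// GemmK is decomposed as GemmK0 * GemmKPack; the descriptors built below keep
// GemmKPack as the last (fastest-varying) dimension so that K can be read in
// KPack-wide vectors and handed to the xdlops GEMM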
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
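// resulting weight layout: [GemmK0, GemmM, GemmKPack]; the unmerged halves of
// GemmK land in dimensions 0 and 2 with GemmM (= K) in between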
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_global_desc =
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0));
assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0));
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
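// the GemmM dimension of the output is viewed as (GemmM / (M1 * M2), M1, M2)
// below, following the per-wave C-register layout reported by GetCLayout()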
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
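// each workgroup owns one GemmMPerBlock x GemmNPerBlock tile of C; the cluster
// descriptor spans the (GemmM / GemmMPerBlock) x (GemmN / GemmNPerBlock) grid of tiles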
// hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
// hack to control index calculation when iterating over the out_m0_m1_m2_n_global
// tensor (hack for the NKHW format)
constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
in_gemmk0_gemmn_gemmk1_global_desc,
out_m0_m1_m2_n_global_desc,
out_gemm_block_cluster_desc,
wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
out_m0_m1_m2_n_global_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
}
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename FloatAB,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPack,
typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmKPack;
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0);
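// for a 1x1, stride-1, dilation-1, unpadded convolution Ho == Hi and Wo == Wi,
// so the input needs no pad/embed transforms: GemmK degenerates to C and
// GemmN to N * Ho * Wo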
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0));
assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0));
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 1, 2, 0, 0>{};
// hack to control index calculation when iterating over the out_m0_m1_m2_n_global
// tensor (hack for the NKHW format)
constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
in_gemmk0_gemmn_gemmk1_global_desc,
out_m0_m1_m2_n_global_desc,
out_gemm_block_cluster_desc,
wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
out_m0_m1_m2_n_global_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
}
} // namespace ck
#endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t KPack,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_xdlops_v2(const FloatAB* p_a_global,
const FloatAB* p_b_global,
FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
const auto K = a_k_m_global_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGlobalDesc,
BGlobalDesc,
CGlobalDesc,
CBlockClusterDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
KPack,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
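// two experimental ways of handing the dynamic tensor descriptors to the kernel:
// pass them by value as kernel arguments, or copy them into device buffers and
// pass __CONSTANT__ void pointers (the two branches below)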
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
const auto kernel = kernel_dynamic_gemm_xdlops_v2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0;
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
return ave_time;
#endif
}
} // namespace ck
#endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r2.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_xdlops_v2r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
using GridwiseGemm =
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
{
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r2 has invalid setting");
}
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
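// unlike the v2 launcher above, here the gridwise GEMM itself builds the
// M0-M1-M2-N view of C and the block-id-to-C-tile cluster adaptor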
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const auto kernel = kernel_dynamic_gemm_xdlops_v2r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0M1M2NGridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);
return ave_time;
}
} // namespace ck
#endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks,
bool CAccessOrderMRepeatNRepeat>
__host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
using GridwiseGemm =
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
CAccessOrderMRepeatNRepeat>;
{
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const auto kernel = kernel_dynamic_gemm_xdlops_v2r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0M1M2NGridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);
return ave_time;
}
} // namespace ck
#endif
@@ -116,8 +116,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
         constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

         // GemmK is different for each GEMM
-        index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
-        index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
+        index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
+        index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);

         index_t GemmK = K * YDotSlice * XDotSlice;
@@ -176,8 +176,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
         constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
         constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

-        constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
-        constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
+        constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
+        constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);

         constexpr index_t HTilda =
             Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
@@ -118,8 +118,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
         constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

         // GemmK is different for each GEMM
-        index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
-        index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
+        index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
+        index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);

         index_t GemmK0 = YDotSlice;
         index_t GemmK1 = XDotSlice;
@@ -180,8 +180,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
         constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
         constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

-        constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
-        constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
+        constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
+        constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);

         constexpr index_t HTilda =
             Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// Number of GEMMs = YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
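// Worked example (illustrative, not from this commit): ConvStrideH = 2 with
// ConvDilationH = 1 gives YTilda = 2, so the H direction contributes two GEMMs;
// for Y = 3, YDot = ceil(3 / 2) = 2, and the GEMM with IYTilda = 0 covers filter
// rows {0, 2} (YDotSlice = 2) while IYTilda = 1 covers row {1} (YDotSlice = 1)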
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t IYTildaValue,
index_t IXTildaValue,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<IYTildaValue>,
Number<IXTildaValue>,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
constexpr auto IYTilda = Number<IYTildaValue>{};
constexpr auto IXTilda = Number<IXTildaValue>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilda);
const auto XDot = math::integer_divide_ceil(X, XTilda);
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on the part of HTilda and WTilda that contributes to the non-padded area of the input tensor
const auto IHTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
const auto IWTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
const auto IHTildaSliceEnd =
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildaSliceEnd =
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
const auto K1 = GemmK1;
const auto K0 = K / K1;
// weight tensor
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilda),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilda),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(IYTilda),
make_freeze_transform(IXTilda),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
#if 1
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// output tensor
// the (0, 0) pad below adds no real padding; it only attaches the padding/bounds check to Ho and Wo
const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilda),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilda),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
transform_dynamic_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
#if 1
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilda, HTilda),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilda, WTilda),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(IYTilda),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_freeze_transform(IXTilda),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_grid_desc,
make_tuple(make_pass_through_transform(C),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// A: out
// B: wei
// C: in
// Number of GEMMs = YTilda * XTilda
// GemmM = N * HTildaSlice * WTildaSlice
// GemmN = C
// GemmK = K * YDotSlice * XDotSlice
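// compared with the v4r1 transform above, the A and B roles are swapped here:
// the output gradient is A and the weight is B, so GemmM becomes
// N * HTildaSlice * WTildaSlice and GemmN becomes C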
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t IYTildaValue,
index_t IXTildaValue,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<IYTildaValue>,
Number<IXTildaValue>,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
constexpr auto IYTilda = Number<IYTildaValue>{};
constexpr auto IXTilda = Number<IXTildaValue>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilda);
const auto XDot = math::integer_divide_ceil(X, XTilda);
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on the part of HTilda and WTilda that contributes to the non-padded area of the input tensor
const auto IHTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
const auto IWTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
const auto IHTildaSliceEnd =
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildaSliceEnd =
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
const auto K1 = GemmK1;
const auto K0 = K / K1;
// A: output tensor
// the (0, 0) pad below adds no real padding; it only attaches the padding/bounds check to Ho and Wo
const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilda),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilda),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
transform_dynamic_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
#if 1
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// B: weight tensor
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilda),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilda),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(IYTilda),
make_freeze_transform(IXTilda),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
#if 1
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// C: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilda, HTilda),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilda, WTilda),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(IYTilda),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_freeze_transform(IXTilda),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
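A note on the #if/#else merge orders above: the #if 1 branch merges (YDotSlice, XDotSlice, K0) while the disabled branch merges (K0, YDotSlice, XDotSlice). The merged GemmK0 length is identical either way; what changes is which component varies fastest, since the last length listed in make_merge_transform is the innermost one. A worked example with purely hypothetical sizes:
// Hypothetical sizes, for illustration only.
constexpr int K = 256, K1 = 4;
constexpr int K0 = K / K1;                         // 64
constexpr int YDotSlice = 2, XDotSlice = 2;
constexpr int GemmK0 = YDotSlice * XDotSlice * K0; // 256, merged K0 dimension shared by A and B
constexpr int GemmK1 = K1;                         // 4, kept as the separate innermost K component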
...@@ -18,9 +18,9 @@ template <typename... Wei, ...@@ -18,9 +18,9 @@ template <typename... Wei,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad( __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc, const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_global_desc, const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_global_desc, const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -31,18 +31,18 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -31,18 +31,18 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_global_desc.GetLength(I1); const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_global_desc.GetLength(I2); const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0]; const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1]; const auto ConvStrideW = conv_strides[I1];
...@@ -57,15 +57,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -57,15 +57,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
const auto InRightPadW = in_right_pads[I1]; const auto InRightPadW = in_right_pads[I1];
// weight tensor // weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor // input tensor
const auto in_n_hip_wip_c_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_global_desc, in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi, InLeftPadW, InRightPadW),
...@@ -73,8 +73,8 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -73,8 +73,8 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_global_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
...@@ -82,22 +82,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -82,22 +82,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_global_desc = const auto in_gemmk_gemmn_grid_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_global_desc, transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor // output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple( return make_tuple(
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
} }
template <typename... Wei, template <typename... Wei,
...@@ -108,9 +108,9 @@ template <typename... Wei, ...@@ -108,9 +108,9 @@ template <typename... Wei,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1( __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc, const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_global_desc, const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_global_desc, const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -121,18 +121,18 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -121,18 +121,18 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_global_desc.GetLength(I1); const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_global_desc.GetLength(I2); const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0]; const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1]; const auto ConvStrideW = conv_strides[I1];
...@@ -151,28 +151,28 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -151,28 +151,28 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
InRightPadW == 0); InRightPadW == 0);
// weight tensor // weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor // input tensor
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( const auto in_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// output tensor // output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple( return make_tuple(
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
} }
} // namespace ck } // namespace ck
......
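The *_1x1 variant above drops the padding and Y/X embedding entirely and describes all three tensors as plain 2-D matrices, so it is only valid when its asserted conditions hold; the caller is expected to check the problem shape before selecting it. A hypothetical host-side guard, illustrative only (the authoritative condition set is the asserts in the function itself):
// Hypothetical dispatch guard, not part of this commit.
const bool use_1x1_path = Y == 1 && X == 1 &&
                          ConvStrideH == 1 && ConvStrideW == 1 &&
                          InLeftPadH == 0 && InLeftPadW == 0 &&
                          InRightPadH == 0 && InRightPadW == 0;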
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
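Since GemmK0 = GemmK / GemmK1 is an exact integer division, this mapping implicitly requires C * Y * X to be divisible by the chosen GemmK1. A worked example with hypothetical sizes:
// Hypothetical sizes, illustrative only.
constexpr int C = 64, Y = 3, X = 3;
constexpr int GemmK1 = 8;
constexpr int GemmK  = C * Y * X;      // 576
static_assert(GemmK % GemmK1 == 0, "GemmK1 must evenly divide C * Y * X");
constexpr int GemmK0 = GemmK / GemmK1; // 72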
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
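This NHWC variant is structurally the same as the NCHW one above, except that GemmK is merged as (Y, X, C) instead of (C, Y, X). Because the last listed length of a merge varies fastest, a GemmK index decomposes differently in the two layouts; a small illustrative sketch (hypothetical helpers, not part of this commit):
// NHWC merge order (Y, X, C): C is the innermost component of GemmK.
inline int gemmk_nhwc(int y, int x, int c, int X, int C) { return (y * X + x) * C + c; }
// NCHW merge order (C, Y, X): X is the innermost component instead.
inline int gemmk_nchw(int c, int y, int x, int Y, int X) { return (c * Y + y) * X + x; }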
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// A: in
// B: wei
// C: out
// GemmM = N * Ho * Wo
// GemmN = K
// GemmK = C * Y * X
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = N * Ho * Wo;
const auto GemmN = K;
const auto GemmK = Y * X * C;
const auto GemmK0 = GemmK / GemmK1;
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmm_grid_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: weight tensor
const auto wei_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
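Unlike the v4r4r2 mappings above, this v4r4r4 variant places the input on the A side and the weight on the B side, so the output GEMM dimensions swap roles: GemmM = N * Ho * Wo and GemmN = K. A worked example with hypothetical sizes:
// Hypothetical problem, illustrative only: N = 128, Ho = Wo = 28, K = 256, C = 64, Y = X = 3.
constexpr int GemmM = 128 * 28 * 28; // N * Ho * Wo = 100352
constexpr int GemmN = 256;           // K
constexpr int GemmK = 3 * 3 * 64;    // Y * X * C = 576; GemmK0 = 72 when GemmK1 = 8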
...@@ -1417,6 +1417,7 @@ struct DynamicUnMerge ...@@ -1417,6 +1417,7 @@ struct DynamicUnMerge
printf("DynamicUnMerge, "); printf("DynamicUnMerge, ");
printf("up_lengths_"); printf("up_lengths_");
print_multi_index(up_lengths_); print_multi_index(up_lengths_);
printf("up_lengths_scan_");
print_multi_index(up_lengths_scan_); print_multi_index(up_lengths_scan_);
printf("}"); printf("}");
} }
...@@ -1439,12 +1440,12 @@ struct DynamicFreeze ...@@ -1439,12 +1440,12 @@ struct DynamicFreeze
template <typename LowIdx, typename UpIdx> template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const const UpIdx& /* idx_up */) const
{ {
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low = low_idx_; idx_low(Number<0>{}) = low_idx_;
} }
template <typename LowIdxDiff, template <typename LowIdxDiff,
...@@ -1453,9 +1454,9 @@ struct DynamicFreeze ...@@ -1453,9 +1454,9 @@ struct DynamicFreeze
typename UpIdx, typename UpIdx,
index_t Hack> index_t Hack>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff& /* idx_diff_up */,
LowIdx& idx_low, LowIdx& /* idx_low */,
const UpIdx& idx_up_new, const UpIdx& /* idx_up_new */,
Number<Hack>) Number<Hack>)
{ {
idx_diff_low(Number<0>{}) = 0; idx_diff_low(Number<0>{}) = 0;
...@@ -1487,6 +1488,73 @@ struct DynamicFreeze ...@@ -1487,6 +1488,73 @@ struct DynamicFreeze
} }
}; };
// Insert a dangling upper dimension that has no corresponding lower dimension
template <typename UpperLength>
struct DynamicInsert
{
using UpLengths = decltype(make_tuple(UpperLength{}));
UpLengths up_lengths_;
__host__ __device__ constexpr DynamicInsert() = default;
__host__ __device__ constexpr DynamicInsert(const UpperLength& up_length)
: up_lengths_{make_tuple(up_length)}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const
{
static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ static void
UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number<Hack>)
{
static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 &&
UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpperLength>::value;
}
__host__ __device__ void Print() const
{
printf("DynamicInsert");
print_multi_index(up_lengths_);
}
};
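DynamicInsert carries no lower dimension at all, so CalculateLowerIndex and UpdateLowerIndex are pure no-ops and the transform only contributes one extra upper length. No factory for it appears in this diff; a hypothetical wrapper in the style of the other make_*_transform helpers is sketched below, for illustration only:
// Hypothetical factory, not part of this commit; shown only for symmetry with
// the other make_*_transform helpers.
template <typename UpperLength>
__host__ __device__ constexpr auto make_insert_transform(const UpperLength& up_length)
{
    return DynamicInsert<UpperLength>{up_length};
}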
template <typename VectorSize, typename UpLength> template <typename VectorSize, typename UpLength>
struct DynamicVectorize struct DynamicVectorize
{ {
...@@ -1572,5 +1640,99 @@ struct DynamicVectorize ...@@ -1572,5 +1640,99 @@ struct DynamicVectorize
} }
}; };
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct DynamicSlice
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
UpLengths up_lengths_;
SliceBegin slice_begin_;
SliceEnd slice_end_;
__host__ __device__ constexpr DynamicSlice() = default;
__host__ __device__ constexpr DynamicSlice(const LowLength& /* low_length */,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
: up_lengths_{make_tuple(slice_end - slice_begin)},
slice_begin_{slice_begin},
slice_end_{slice_end}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& /* idx_up_new */,
Number<Hack>)
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = Number<0>{};
idx_diff_low(I0) = idx_diff_up[I0];
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) const
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<SliceBegin>::value &&
is_known_at_compile_time<SliceEnd>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicSlice, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("slice_begin_ %d", index_t{slice_begin_});
printf("slice_end %d", index_t{slice_end_});
printf("}");
}
};
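DynamicSlice exposes an upper length of slice_end - slice_begin and maps upper index u to lower index u + slice_begin; it performs no bounds check of its own (the validity queries always return true), so the caller must keep the window inside the lower length. A sketch with hypothetical numbers:
// Illustrative only: a length-10 lower dimension sliced to the window [2, 7)
// exposes an upper length of 5, and upper index 3 maps to lower index 3 + 2 = 5.
const auto slice_example = DynamicSlice<index_t, index_t, index_t>{10, 2, 7};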
} // namespace ck } // namespace ck
#endif #endif
...@@ -85,6 +85,14 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i ...@@ -85,6 +85,14 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
return DynamicFreeze<LowerIndex>{low_idx}; return DynamicFreeze<LowerIndex>{low_idx};
} }
template <typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
{
return DynamicSlice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}
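Note that the third argument here is the slice end, not the slice length; the exposed upper length is slice_end - slice_begin. This is the factory used by the backward-data transform earlier in this diff (e.g. make_slice_transform(YDot, I0, YDotSlice)). A minimal sketch with runtime values:
// Illustrative only: keep lower indices [4, 12) of a length-16 dimension,
// exposing an upper dimension of length 8.
const auto w_window = make_slice_transform(index_t{16}, index_t{4}, index_t{12});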
template <typename VectorSize, typename UpLength> template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size, __host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length) const UpLength& up_length)
......
...@@ -137,7 +137,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths ...@@ -137,7 +137,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
math::multiplies_v2{}, math::multiplies_v2{},
Number<stride_n_minus_2>{}, Number<stride_n_minus_2>{},
i + I1, i + I1,
Number<N - 2>{}, Number<N - 1>{},
I1); I1);
} }
}, },
......
...@@ -184,6 +184,18 @@ struct TensorAdaptor ...@@ -184,6 +184,18 @@ struct TensorAdaptor
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
} }
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
bool is_known = true;
static_for<0, Transforms::Size(), 1>{}([&](auto i) {
is_known &=
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
});
return is_known && is_known_at_compile_time<ElementSize>::value;
}
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
......