Commit c03045ce authored by Chao Liu
Browse files

rename

parent b2589957
...@@ -78,7 +78,7 @@ InLeftPads size 2, {1, 1, } ...@@ -78,7 +78,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, } InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, } ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, } ConvDilations size 2, {1, 1, }
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{216, 256, 8} a_k0_m_k1_grid_desc{216, 256, 8}
b_k0_n_k1_grid_desc{216, 165888, 8} b_k0_n_k1_grid_desc{216, 165888, 8}
c_m_n_grid_desc{ 256, 165888} c_m_n_grid_desc{ 256, 165888}
...@@ -100,7 +100,7 @@ InLeftPads size 2, {1, 1, } ...@@ -100,7 +100,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, } InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, } ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, } ConvDilations size 2, {1, 1, }
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{288, 1024, 8} a_k0_m_k1_grid_desc{288, 1024, 8}
b_k0_n_k1_grid_desc{288, 50176, 8} b_k0_n_k1_grid_desc{288, 50176, 8}
c_m_n_grid_desc{ 1024, 50176} c_m_n_grid_desc{ 1024, 50176}
...@@ -122,7 +122,7 @@ InLeftPads size 2, {1, 1, } ...@@ -122,7 +122,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, } InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, } ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, } ConvDilations size 2, {1, 1, }
device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{216, 165888, 8} a_k0_m_k1_grid_desc{216, 165888, 8}
b_k0_n_k1_grid_desc{216, 256, 8} b_k0_n_k1_grid_desc{216, 256, 8}
c_m_n_grid_desc{ 165888, 256} c_m_n_grid_desc{ 165888, 256}
...@@ -144,7 +144,7 @@ InLeftPads size 2, {1, 1, } ...@@ -144,7 +144,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, } InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, } ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, } ConvDilations size 2, {1, 1, }
device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8} a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8} b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024} c_m_n_grid_desc{ 50176, 1024}
...@@ -166,7 +166,7 @@ InLeftPads size 2, {1, 1, } ...@@ -166,7 +166,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, } InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, } ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, } ConvDilations size 2, {1, 1, }
device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8} a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8} b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024} c_m_n_grid_desc{ 50176, 1024}
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -23,9 +23,9 @@ template <typename... Wei, ...@@ -23,9 +23,9 @@ template <typename... Wei,
index_t GemmK1Value> index_t GemmK1Value>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc, const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc, const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc, const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -102,7 +102,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -102,7 +102,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const auto K0 = K / K1; const auto K0 = K / K1;
// weight tensor // weight tensor
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc, wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilda), make_embed_transform(make_tuple(YDot, YTilda),
...@@ -114,28 +114,28 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -114,28 +114,28 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(IYTilda), make_freeze_transform(IYTilda),
make_freeze_transform(IXTilda), make_freeze_transform(IXTilda),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
Sequence<2>{}, Sequence<2>{},
Sequence<4>{}, Sequence<4>{},
Sequence<5>{}), Sequence<5>{}),
make_tuple(Sequence<0, 1>{}, make_tuple(Sequence<0, 1>{},
Sequence<2>{}, Sequence<2>{},
Sequence<3>{}, Sequence<3>{},
Sequence<>{}, Sequence<>{},
Sequence<>{}, Sequence<>{},
Sequence<4>{})); Sequence<4>{}));
#if 1 #if 1
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc, wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -143,7 +143,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -143,7 +143,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else #else
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc, wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -154,7 +154,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -154,7 +154,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
// output tensor // output tensor
// this add padding check // this add padding check
const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor( const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc, out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0), make_pad_transform(Ho, I0, I0),
...@@ -163,7 +163,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -163,7 +163,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor( const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc, out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilda), make_embed_transform(make_tuple(YDot, HTilda),
...@@ -175,7 +175,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -175,7 +175,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
transform_dynamic_tensor_descriptor( transform_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc, out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
...@@ -197,7 +197,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -197,7 +197,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
Sequence<5, 6>{})); Sequence<5, 6>{}));
#if 1 #if 1
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
...@@ -205,7 +205,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -205,7 +205,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else #else
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
...@@ -215,7 +215,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -215,7 +215,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
#endif #endif
// input tensor // input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
...@@ -224,7 +224,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -224,7 +224,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilda, HTilda), make_embed_transform(make_tuple(YTilda, HTilda),
...@@ -235,7 +235,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -235,7 +235,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_freeze_transform(IYTilda), make_freeze_transform(IYTilda),
...@@ -256,7 +256,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( ...@@ -256,7 +256,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
Sequence<2>{}, Sequence<2>{},
Sequence<3>{})); Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_grid_desc, in_n_htildaslice_wtildaslice_c_grid_desc,
make_tuple(make_pass_through_transform(C), make_tuple(make_pass_through_transform(C),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))), make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))),
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -26,9 +26,9 @@ template <typename... Wei, ...@@ -26,9 +26,9 @@ template <typename... Wei,
index_t GemmK1Value> index_t GemmK1Value>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc, const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc, const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc, const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -106,7 +106,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -106,7 +106,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
// A: output tensor // A: output tensor
// this add padding check // this add padding check
const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor( const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc, out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0), make_pad_transform(Ho, I0, I0),
...@@ -115,7 +115,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -115,7 +115,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor( const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc, out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilda), make_embed_transform(make_tuple(YDot, HTilda),
...@@ -127,7 +127,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -127,7 +127,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
transform_dynamic_tensor_descriptor( transform_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc, out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
...@@ -149,7 +149,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -149,7 +149,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
Sequence<5, 6>{})); Sequence<5, 6>{}));
#if 1 #if 1
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
...@@ -157,7 +157,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -157,7 +157,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else #else
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
...@@ -167,7 +167,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -167,7 +167,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#endif #endif
// B: weight tensor // B: weight tensor
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc, wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilda), make_embed_transform(make_tuple(YDot, YTilda),
...@@ -179,28 +179,28 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -179,28 +179,28 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(IYTilda), make_freeze_transform(IYTilda),
make_freeze_transform(IXTilda), make_freeze_transform(IXTilda),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
Sequence<2>{}, Sequence<2>{},
Sequence<4>{}, Sequence<4>{},
Sequence<5>{}), Sequence<5>{}),
make_tuple(Sequence<0, 1>{}, make_tuple(Sequence<0, 1>{},
Sequence<2>{}, Sequence<2>{},
Sequence<3>{}, Sequence<3>{},
Sequence<>{}, Sequence<>{},
Sequence<>{}, Sequence<>{},
Sequence<4>{})); Sequence<4>{}));
#if 1 #if 1
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc, wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -208,7 +208,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -208,7 +208,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else #else
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc, wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -218,7 +218,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -218,7 +218,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#endif #endif
// C: input tensor // C: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
...@@ -227,7 +227,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -227,7 +227,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilda, HTilda), make_embed_transform(make_tuple(YTilda, HTilda),
...@@ -238,7 +238,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -238,7 +238,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_freeze_transform(IYTilda), make_freeze_transform(IYTilda),
...@@ -259,7 +259,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( ...@@ -259,7 +259,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
Sequence<2>{}, Sequence<2>{},
Sequence<3>{})); Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_grid_desc, in_n_htildaslice_wtildaslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(C)), make_pass_through_transform(C)),
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -18,9 +18,9 @@ template <typename... Wei, ...@@ -18,9 +18,9 @@ template <typename... Wei,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc, const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
const auto InRightPadW = in_right_pads[I1]; const auto InRightPadW = in_right_pads[I1];
// weight tensor // weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor // input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc, in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc, in_n_c_hip_wip_global_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_global_desc = const auto in_gemmk_gemmn_global_desc =
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)), make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor // output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -109,9 +109,9 @@ template <typename... Wei, ...@@ -109,9 +109,9 @@ template <typename... Wei,
typename InRightPads> typename InRightPads>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc, const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -147,14 +147,14 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( ...@@ -147,14 +147,14 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0); assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);
// weight tensor // weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor // input tensor
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc, in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -164,15 +164,15 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( ...@@ -164,15 +164,15 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_global_desc = const auto in_gemmk_gemmn_global_desc =
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)), make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor // output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -189,9 +189,9 @@ template <typename... Wei, ...@@ -189,9 +189,9 @@ template <typename... Wei,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1( __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc, const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -229,22 +229,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -229,22 +229,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
InRightPadW == 0); InRightPadW == 0);
// weight tensor // weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor // input tensor
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc, in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor // output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -18,9 +18,9 @@ template <typename... Wei, ...@@ -18,9 +18,9 @@ template <typename... Wei,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad( __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc, const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc, const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc, const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
const auto InRightPadW = in_right_pads[I1]; const auto InRightPadW = in_right_pads[I1];
// weight tensor // weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor // input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
...@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
...@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc = const auto in_gemmk_gemmn_grid_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor // output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
...@@ -108,9 +108,9 @@ template <typename... Wei, ...@@ -108,9 +108,9 @@ template <typename... Wei,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1( __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc, const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc, const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc, const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -148,22 +148,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ ...@@ -148,22 +148,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
InRightPadW == 0); InRightPadW == 0);
// weight tensor // weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor // input tensor
const auto in_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor( const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
// output tensor // output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -20,9 +20,9 @@ template <typename... Wei, ...@@ -20,9 +20,9 @@ template <typename... Wei,
index_t GemmK1Value> index_t GemmK1Value>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc, const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc, const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc, const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( ...@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const auto GemmK0 = GemmK / GemmK1; const auto GemmK0 = GemmK / GemmK1;
// weight tensor // weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
wei_gemmk_gemmm_grid_desc, transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)), make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor // input tensor
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
in_n_c_hi_wi_grid_desc, in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( ...@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
in_n_c_hip_wip_grid_desc, in_n_c_hip_wip_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( ...@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_grid_desc = const auto in_gemmk_gemmn_grid_desc =
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)), make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto in_gemmk0_gemmn_gemmk1_grid_desc =
in_gemmk_gemmn_grid_desc, transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)), make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor // output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -20,9 +20,9 @@ template <typename... Wei, ...@@ -20,9 +20,9 @@ template <typename... Wei,
index_t GemmK1Value> index_t GemmK1Value>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc, const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc, const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc, const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( ...@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const auto GemmK0 = GemmK / GemmK1; const auto GemmK0 = GemmK / GemmK1;
// weight tensor // weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
wei_gemmk_gemmm_grid_desc, transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)), make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor // input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
...@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( ...@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
...@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( ...@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc = const auto in_gemmk_gemmn_grid_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto in_gemmk0_gemmn_gemmk1_grid_desc =
in_gemmk_gemmn_grid_desc, transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)), make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor // output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -23,9 +23,9 @@ template <typename... In, ...@@ -23,9 +23,9 @@ template <typename... In,
index_t GemmK1Value> index_t GemmK1Value>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc, const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc, const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc, const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -70,7 +70,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( ...@@ -70,7 +70,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const auto GemmK0 = GemmK / GemmK1; const auto GemmK0 = GemmK / GemmK1;
// A: input tensor // A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
...@@ -79,7 +79,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( ...@@ -79,7 +79,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
...@@ -89,36 +89,36 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( ...@@ -89,36 +89,36 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmm_grid_desc = const auto in_gemmk_gemmm_grid_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto in_gemmk0_gemmm_gemmk1_grid_desc =
in_gemmk_gemmm_grid_desc, transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)), make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: weight tensor // B: weight tensor
const auto wei_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
wei_gemmk_gemmn_grid_desc, transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)), make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor // C: output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -24,9 +24,9 @@ template <typename... Wei, ...@@ -24,9 +24,9 @@ template <typename... Wei,
typename C0Type> typename C0Type>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc, const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc, const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc, const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -68,15 +68,15 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( ...@@ -68,15 +68,15 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
const auto C1 = C / C0; const auto C1 = C / C0;
// weight tensor // weight tensor
const auto wei_gk0_gm0_gm1_gk1_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gk0_gm0_gm1_gk1_grid_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_unmerge_transform(make_tuple(I1, K)), make_tuple(make_unmerge_transform(make_tuple(I1, K)),
make_unmerge_transform(make_tuple(C0, C1 * Y * X))), make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{})); make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));
// input tensor // input tensor
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
in_n_c_hi_wi_grid_desc, in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C), make_pass_through_transform(C),
...@@ -85,7 +85,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( ...@@ -85,7 +85,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
in_n_c_hip_wip_grid_desc, in_n_c_hip_wip_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)), make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(C0, C1)), make_unmerge_transform(make_tuple(C0, C1)),
...@@ -94,7 +94,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( ...@@ -94,7 +94,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_dynamic_tensor_descriptor( const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor(
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc, in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C1, Y, X)), make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
make_pass_through_transform(N0), make_pass_through_transform(N0),
...@@ -105,17 +105,17 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( ...@@ -105,17 +105,17 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
// output tensor // output tensor
const auto out_n_k_howo_grid_desc = const auto out_n_k_howo_grid_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)); make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo));
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor( const auto out_n0_n1_1_k_howo_grid_desc =
out_n_k_howo_grid_desc, transform_tensor_descriptor(out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)), make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(I1, K)), make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)), make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor( const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor(
out_n0_n1_1_k_howo_grid_desc, out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1), make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K), make_pass_through_transform(K),
......
#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP #ifndef CK_MULTI_INDEX_TRANSFORM_HPP
#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP #define CK_MULTI_INDEX_TRANSFORM_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index.hpp" #include "multi_index.hpp"
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace ck { namespace ck {
template <typename LowLength> template <typename LowLength>
struct DynamicPassThrough struct PassThrough
{ {
using LowerIndex = MultiIndex<1>; using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>;
...@@ -16,9 +16,9 @@ struct DynamicPassThrough ...@@ -16,9 +16,9 @@ struct DynamicPassThrough
UpLengths up_lengths_; UpLengths up_lengths_;
__host__ __device__ constexpr DynamicPassThrough() = default; __host__ __device__ constexpr PassThrough() = default;
__host__ __device__ constexpr DynamicPassThrough(const LowLength& low_length) __host__ __device__ constexpr PassThrough(const LowLength& low_length)
: up_lengths_{make_tuple(low_length)} : up_lengths_{make_tuple(low_length)}
{ {
} }
...@@ -82,33 +82,36 @@ struct DynamicPassThrough ...@@ -82,33 +82,36 @@ struct DynamicPassThrough
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicPassThrough, "); printf("PassThrough, ");
printf("up_lengths_"); printf("up_lengths_");
print_multi_index(up_lengths_); print_multi_index(up_lengths_);
printf("}"); printf("}");
} }
}; };
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false> template <typename LowLength,
struct DynamicPad typename LeftPadLength,
typename RightPadLength,
bool SkipIsValidCheck = false>
struct Pad
{ {
using LowerIndex = MultiIndex<1>; using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{} + RightPad{})); using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));
UpLengths up_lengths_; UpLengths up_lengths_;
LeftPad left_pad_; LeftPadLength left_pad_length_;
RightPad right_pad_; RightPadLength right_pad_length_;
__host__ __device__ constexpr DynamicPad() = default; __host__ __device__ constexpr Pad() = default;
__host__ __device__ constexpr DynamicPad(const LowLength& low_length, __host__ __device__ constexpr Pad(const LowLength& low_length,
const LeftPad& left_pad, const LeftPadLength& left_pad_length,
const RightPad& right_pad) const RightPadLength& right_pad_length)
: up_lengths_{make_tuple(low_length + left_pad + right_pad)}, : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
left_pad_{left_pad}, left_pad_length_{left_pad_length},
right_pad_{right_pad} right_pad_length_{right_pad_length}
{ {
} }
...@@ -125,7 +128,7 @@ struct DynamicPad ...@@ -125,7 +128,7 @@ struct DynamicPad
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_; idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
} }
template <typename LowIdxDiff, template <typename LowIdxDiff,
...@@ -161,45 +164,46 @@ struct DynamicPad ...@@ -161,45 +164,46 @@ struct DynamicPad
__host__ __device__ constexpr bool __host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{ {
return SkipIsValidCheck || ((idx_up[Number<0>{}] >= left_pad_) && return SkipIsValidCheck ||
(idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_)); ((idx_up[Number<0>{}] >= left_pad_length_) &&
(idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_));
} }
__host__ __device__ static constexpr bool IsKnownAtCompileTime() __host__ __device__ static constexpr bool IsKnownAtCompileTime()
{ {
return is_known_at_compile_time<UpLengths>::value && return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<LeftPad>::value && is_known_at_compile_time<LeftPadLength>::value &&
is_known_at_compile_time<RightPad>::value; is_known_at_compile_time<RightPadLength>::value;
} }
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicPad, "); printf("Pad, ");
printf("up_lengths_"); printf("up_lengths_");
print_multi_index(up_lengths_); print_multi_index(up_lengths_);
printf("left_pad_ %d", index_t{left_pad_}); printf("left_pad_length %d", index_t{left_pad_length_});
printf("right_pad_ %d", index_t{right_pad_}); printf("right_pad_length %d", index_t{right_pad_length_});
printf("}"); printf("}");
} }
}; };
template <typename LowLength, typename LeftPad, bool SkipIsValidCheck = false> template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct DynamicLeftPad struct LeftPad
{ {
using LowerIndex = MultiIndex<1>; using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{})); using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));
UpLengths up_lengths_; UpLengths up_lengths_;
LeftPad left_pad_; LeftPadLength left_pad_length_;
__host__ __device__ constexpr DynamicLeftPad() = default; __host__ __device__ constexpr LeftPad() = default;
__host__ __device__ constexpr DynamicLeftPad(const LowLength& low_length, __host__ __device__ constexpr LeftPad(const LowLength& low_length,
const LeftPad& left_pad) const LeftPadLength& left_pad_length)
: up_lengths_{make_tuple(low_length + left_pad)}, left_pad_{left_pad} : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
{ {
} }
...@@ -216,7 +220,7 @@ struct DynamicLeftPad ...@@ -216,7 +220,7 @@ struct DynamicLeftPad
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_; idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
} }
template <typename LowIdxDiff, template <typename LowIdxDiff,
...@@ -252,45 +256,45 @@ struct DynamicLeftPad ...@@ -252,45 +256,45 @@ struct DynamicLeftPad
__host__ __device__ constexpr bool __host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{ {
return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_); return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_);
} }
__host__ __device__ static constexpr bool IsKnownAtCompileTime() __host__ __device__ static constexpr bool IsKnownAtCompileTime()
{ {
return is_known_at_compile_time<UpLengths>::value && return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<LeftPad>::value; is_known_at_compile_time<LeftPadLength>::value;
} }
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicLeftPad, "); printf("LeftPad, ");
printf("up_lengths_"); printf("up_lengths_");
print_multi_index(up_lengths_); print_multi_index(up_lengths_);
printf("left_pad_ %d", index_t{left_pad_}); printf("left_pad_length_ %d", index_t{left_pad_length_});
printf("}"); printf("}");
} }
}; };
template <typename LowLength, typename RightPad, bool SkipIsValidCheck = false> template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct DynamicRightPad struct RightPad
{ {
using LowerIndex = MultiIndex<1>; using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(LowLength{} + RightPad{})); using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));
UpLengths up_lengths_; UpLengths up_lengths_;
LowLength low_length_; LowLength low_length_;
RightPad right_pad_; RightPadLength right_pad_length_;
__host__ __device__ constexpr DynamicRightPad() = default; __host__ __device__ constexpr RightPad() = default;
__host__ __device__ constexpr DynamicRightPad(const LowLength& low_length, __host__ __device__ constexpr RightPad(const LowLength& low_length,
const RightPad& right_pad) const RightPadLength& right_pad_length)
: up_lengths_{make_tuple(low_length + right_pad)}, : up_lengths_{make_tuple(low_length + right_pad_length)},
low_length_{low_length}, low_length_{low_length},
right_pad_{right_pad} right_pad_length_{right_pad_length}
{ {
} }
...@@ -350,17 +354,17 @@ struct DynamicRightPad ...@@ -350,17 +354,17 @@ struct DynamicRightPad
{ {
return is_known_at_compile_time<UpLengths>::value && return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<LowLength>::value && is_known_at_compile_time<LowLength>::value &&
is_known_at_compile_time<RightPad>::value; is_known_at_compile_time<RightPadLength>::value;
} }
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicRightPad, "); printf("RightPad, ");
printf("up_lengths_"); printf("up_lengths_");
print_multi_index(up_lengths_); print_multi_index(up_lengths_);
printf("low_length_ %d", index_t{low_length_}); printf("low_length_ %d", index_t{low_length_});
printf("left_pad_ %d", index_t{right_pad_}); printf("left_pad_length_ %d", index_t{right_pad_length_});
printf("}"); printf("}");
} }
}; };
...@@ -374,7 +378,7 @@ struct DynamicRightPad ...@@ -374,7 +378,7 @@ struct DynamicRightPad
template <typename UpLengths, template <typename UpLengths,
typename Coefficients, typename Coefficients,
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false> typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
struct DynamicEmbed struct Embed
{ {
static constexpr index_t NDimUp = UpLengths::Size(); static constexpr index_t NDimUp = UpLengths::Size();
...@@ -384,10 +388,10 @@ struct DynamicEmbed ...@@ -384,10 +388,10 @@ struct DynamicEmbed
UpLengths up_lengths_; UpLengths up_lengths_;
Coefficients coefficients_; Coefficients coefficients_;
__host__ __device__ constexpr DynamicEmbed() = default; __host__ __device__ constexpr Embed() = default;
__host__ __device__ constexpr DynamicEmbed(const UpLengths& up_lengths, __host__ __device__ constexpr Embed(const UpLengths& up_lengths,
const Coefficients& coefficients) const Coefficients& coefficients)
: up_lengths_{up_lengths}, coefficients_{coefficients} : up_lengths_{up_lengths}, coefficients_{coefficients}
{ {
} }
...@@ -458,7 +462,7 @@ struct DynamicEmbed ...@@ -458,7 +462,7 @@ struct DynamicEmbed
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicEmbed, "); printf("Embed, ");
printf("up_lengths_ "); printf("up_lengths_ ");
print_multi_index(up_lengths_); print_multi_index(up_lengths_);
printf("coefficients_ "); printf("coefficients_ ");
...@@ -470,7 +474,7 @@ struct DynamicEmbed ...@@ -470,7 +474,7 @@ struct DynamicEmbed
// Implementation of "Merge" transformation primitive that uses regular to do lowering of // Implementation of "Merge" transformation primitive that uses regular to do lowering of
// multi-index and use carry-and-borrow check to do lowering of multi-index delta // multi-index and use carry-and-borrow check to do lowering of multi-index delta
template <typename LowLengths> template <typename LowLengths>
struct DynamicMerge_v1_carry_check struct Merge_v1_carry_check
{ {
static constexpr index_t NDimLow = LowLengths::Size(); static constexpr index_t NDimLow = LowLengths::Size();
...@@ -487,9 +491,9 @@ struct DynamicMerge_v1_carry_check ...@@ -487,9 +491,9 @@ struct DynamicMerge_v1_carry_check
LowLengthsScan low_lengths_scan_; LowLengthsScan low_lengths_scan_;
UpLengths up_lengths_; UpLengths up_lengths_;
__host__ __device__ constexpr DynamicMerge_v1_carry_check() = default; __host__ __device__ constexpr Merge_v1_carry_check() = default;
__host__ __device__ constexpr DynamicMerge_v1_carry_check(const LowLengths& low_lengths) __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths)
: low_lengths_{low_lengths}, : low_lengths_{low_lengths},
low_lengths_scan_{ low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})}, container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
...@@ -555,7 +559,7 @@ struct DynamicMerge_v1_carry_check ...@@ -555,7 +559,7 @@ struct DynamicMerge_v1_carry_check
LowerIndex idx_low_length_minus_idx_diff_low_const; LowerIndex idx_low_length_minus_idx_diff_low_const;
LowerIndex idx_low_length_plus_idx_diff_low_const; LowerIndex idx_low_length_plus_idx_diff_low_const;
#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE #if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
index_t tmp = idx_diff_up[Number<0>{}]; index_t tmp = idx_diff_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&](auto i) { static_for<0, NDimLow - 1, 1>{}([&](auto i) {
...@@ -698,7 +702,7 @@ struct DynamicMerge_v1_carry_check ...@@ -698,7 +702,7 @@ struct DynamicMerge_v1_carry_check
LowerIndex idx_low_length_minus_idx_diff_low_const; LowerIndex idx_low_length_minus_idx_diff_low_const;
LowerIndex idx_low_length_plus_idx_diff_low_const; LowerIndex idx_low_length_plus_idx_diff_low_const;
#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE #if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
index_t tmp = idx_diff_up[Number<0>{}]; index_t tmp = idx_diff_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&](auto i) { static_for<0, NDimLow - 1, 1>{}([&](auto i) {
...@@ -838,7 +842,7 @@ struct DynamicMerge_v1_carry_check ...@@ -838,7 +842,7 @@ struct DynamicMerge_v1_carry_check
// very expensive. // very expensive.
LowerIndex idx_diff_low_const; LowerIndex idx_diff_low_const;
#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE #if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
index_t tmp = idx_diff_up[Number<0>{}]; index_t tmp = idx_diff_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&](auto i) { static_for<0, NDimLow - 1, 1>{}([&](auto i) {
...@@ -981,7 +985,7 @@ struct DynamicMerge_v1_carry_check ...@@ -981,7 +985,7 @@ struct DynamicMerge_v1_carry_check
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicMerge_v1_carry_check, "); printf("Merge_v1_carry_check, ");
printf("low_lengths_ "); printf("low_lengths_ ");
print_multi_index(low_lengths_); print_multi_index(low_lengths_);
printf("low_lengths_scan_ "); printf("low_lengths_scan_ ");
...@@ -1025,7 +1029,7 @@ struct lambda_merge_generate_MagicDivision_calculate_magic_shift ...@@ -1025,7 +1029,7 @@ struct lambda_merge_generate_MagicDivision_calculate_magic_shift
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be // 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
// non-negative. // non-negative.
template <typename LowLengths> template <typename LowLengths>
struct DynamicMerge_v2_magic_division struct Merge_v2_magic_division
{ {
static constexpr index_t NDimLow = LowLengths::Size(); static constexpr index_t NDimLow = LowLengths::Size();
...@@ -1048,9 +1052,9 @@ struct DynamicMerge_v2_magic_division ...@@ -1048,9 +1052,9 @@ struct DynamicMerge_v2_magic_division
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_; LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
UpLengths up_lengths_; UpLengths up_lengths_;
__host__ __device__ constexpr DynamicMerge_v2_magic_division() = default; __host__ __device__ constexpr Merge_v2_magic_division() = default;
__host__ __device__ constexpr DynamicMerge_v2_magic_division(const LowLengths& low_lengths) __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
: low_lengths_{low_lengths}, : low_lengths_{low_lengths},
low_lengths_magic_divisor_multiplier_{generate_tuple( low_lengths_magic_divisor_multiplier_{generate_tuple(
[&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); }, [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
...@@ -1151,7 +1155,7 @@ struct DynamicMerge_v2_magic_division ...@@ -1151,7 +1155,7 @@ struct DynamicMerge_v2_magic_division
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicMerge_v2_magic_division, "); printf("Merge_v2_magic_division, ");
printf("low_lengths_ "); printf("low_lengths_ ");
print_multi_index(low_lengths_); print_multi_index(low_lengths_);
printf("low_lengths_magic_divisor_multiplier_ "); printf("low_lengths_magic_divisor_multiplier_ ");
...@@ -1177,7 +1181,7 @@ struct DynamicMerge_v2_magic_division ...@@ -1177,7 +1181,7 @@ struct DynamicMerge_v2_magic_division
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be // 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
// non-negative. // non-negative.
template <typename LowLengths> template <typename LowLengths>
struct DynamicMerge_v2r2_magic_division struct Merge_v2r2_magic_division
{ {
static constexpr index_t NDimLow = LowLengths::Size(); static constexpr index_t NDimLow = LowLengths::Size();
...@@ -1204,9 +1208,9 @@ struct DynamicMerge_v2r2_magic_division ...@@ -1204,9 +1208,9 @@ struct DynamicMerge_v2r2_magic_division
LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_; LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_;
UpLengths up_lengths_; UpLengths up_lengths_;
__host__ __device__ constexpr DynamicMerge_v2r2_magic_division() = default; __host__ __device__ constexpr Merge_v2r2_magic_division() = default;
__host__ __device__ constexpr DynamicMerge_v2r2_magic_division(const LowLengths& low_lengths) __host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths)
: low_lengths_{low_lengths}, : low_lengths_{low_lengths},
low_lengths_scan_{ low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})}, container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
...@@ -1308,7 +1312,7 @@ struct DynamicMerge_v2r2_magic_division ...@@ -1308,7 +1312,7 @@ struct DynamicMerge_v2r2_magic_division
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicMerge_v2r2_magic_division, "); printf("Merge_v2r2_magic_division, ");
printf("low_lengths_ "); printf("low_lengths_ ");
print_multi_index(low_lengths_); print_multi_index(low_lengths_);
printf("low_lengths_scan "); printf("low_lengths_scan ");
...@@ -1324,7 +1328,7 @@ struct DynamicMerge_v2r2_magic_division ...@@ -1324,7 +1328,7 @@ struct DynamicMerge_v2r2_magic_division
}; };
template <typename UpLengths, bool Use24BitIntegerCalculation> template <typename UpLengths, bool Use24BitIntegerCalculation>
struct DynamicUnMerge struct UnMerge
{ {
static constexpr index_t NDimUp = UpLengths::Size(); static constexpr index_t NDimUp = UpLengths::Size();
...@@ -1337,9 +1341,9 @@ struct DynamicUnMerge ...@@ -1337,9 +1341,9 @@ struct DynamicUnMerge
UpLengths up_lengths_; UpLengths up_lengths_;
UpLengthsScan up_lengths_scan_; UpLengthsScan up_lengths_scan_;
__host__ __device__ constexpr DynamicUnMerge() = default; __host__ __device__ constexpr UnMerge() = default;
__host__ __device__ constexpr DynamicUnMerge(const UpLengths& up_lengths) __host__ __device__ constexpr UnMerge(const UpLengths& up_lengths)
: up_lengths_{up_lengths}, : up_lengths_{up_lengths},
up_lengths_scan_{ up_lengths_scan_{
container_reverse_exclusive_scan(up_lengths, math::multiplies_v2{}, Number<1>{})} container_reverse_exclusive_scan(up_lengths, math::multiplies_v2{}, Number<1>{})}
...@@ -1414,7 +1418,7 @@ struct DynamicUnMerge ...@@ -1414,7 +1418,7 @@ struct DynamicUnMerge
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicUnMerge, "); printf("UnMerge, ");
printf("up_lengths_"); printf("up_lengths_");
print_multi_index(up_lengths_); print_multi_index(up_lengths_);
printf("up_lengths_scan_"); printf("up_lengths_scan_");
...@@ -1424,13 +1428,13 @@ struct DynamicUnMerge ...@@ -1424,13 +1428,13 @@ struct DynamicUnMerge
}; };
template <typename LowerIndex> template <typename LowerIndex>
struct DynamicFreeze struct Freeze
{ {
LowerIndex low_idx_; LowerIndex low_idx_;
__host__ __device__ constexpr DynamicFreeze() = default; __host__ __device__ constexpr Freeze() = default;
__host__ __device__ constexpr DynamicFreeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} __host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
...@@ -1483,22 +1487,22 @@ struct DynamicFreeze ...@@ -1483,22 +1487,22 @@ struct DynamicFreeze
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("DynamicFreeze"); printf("Freeze");
printf("low_idx_ %d", index_t{low_idx_}); printf("low_idx_ %d", index_t{low_idx_});
} }
}; };
// Insert a dangling upper dimension without lower dimension // Insert a dangling upper dimension without lower dimension
template <typename UpperLength> template <typename UpperLength>
struct DynamicInsert struct Insert
{ {
using UpLengths = decltype(make_tuple(UpperLength{})); using UpLengths = decltype(make_tuple(UpperLength{}));
UpLengths up_lengths_; UpLengths up_lengths_;
__host__ __device__ constexpr DynamicInsert() = default; __host__ __device__ constexpr Insert() = default;
__host__ __device__ constexpr DynamicInsert(const UpperLength& up_length) __host__ __device__ constexpr Insert(const UpperLength& up_length)
: up_lengths_{make_tuple(up_length)} : up_lengths_{make_tuple(up_length)}
{ {
} }
...@@ -1550,13 +1554,13 @@ struct DynamicInsert ...@@ -1550,13 +1554,13 @@ struct DynamicInsert
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("DynamicInsert"); printf("Insert");
print_multi_index(up_lengths_); print_multi_index(up_lengths_);
} }
}; };
template <typename VectorSize, typename UpLength> template <typename VectorSize, typename UpLength>
struct DynamicVectorize struct Vectorize
{ {
using LowerIndex = MultiIndex<1>; using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>;
...@@ -1566,10 +1570,10 @@ struct DynamicVectorize ...@@ -1566,10 +1570,10 @@ struct DynamicVectorize
UpLengths up_lengths_; UpLengths up_lengths_;
VectorSize vector_size_; VectorSize vector_size_;
__host__ __device__ constexpr DynamicVectorize() = default; __host__ __device__ constexpr Vectorize() = default;
__host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size, __host__ __device__ constexpr Vectorize(const VectorSize& vector_size,
const UpLength& up_length) const UpLength& up_length)
: vector_size_{vector_size}, up_lengths_{make_tuple(up_length)} : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
{ {
} }
...@@ -1633,7 +1637,7 @@ struct DynamicVectorize ...@@ -1633,7 +1637,7 @@ struct DynamicVectorize
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicVectorize, "); printf("Vectorize, ");
printf("up_lengths_"); printf("up_lengths_");
print_multi_index(up_lengths_); print_multi_index(up_lengths_);
printf("}"); printf("}");
...@@ -1641,7 +1645,7 @@ struct DynamicVectorize ...@@ -1641,7 +1645,7 @@ struct DynamicVectorize
}; };
template <typename LowLength, typename SliceBegin, typename SliceEnd> template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct DynamicSlice struct Slice
{ {
using LowerIndex = MultiIndex<1>; using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>;
...@@ -1652,11 +1656,11 @@ struct DynamicSlice ...@@ -1652,11 +1656,11 @@ struct DynamicSlice
SliceBegin slice_begin_; SliceBegin slice_begin_;
SliceEnd slice_end_; SliceEnd slice_end_;
__host__ __device__ constexpr DynamicSlice() = default; __host__ __device__ constexpr Slice() = default;
__host__ __device__ constexpr DynamicSlice(const LowLength&, __host__ __device__ constexpr Slice(const LowLength&,
const SliceBegin& slice_begin, const SliceBegin& slice_begin,
const SliceEnd& slice_end) const SliceEnd& slice_end)
: up_lengths_{make_tuple(slice_end - slice_begin)}, : up_lengths_{make_tuple(slice_end - slice_begin)},
slice_begin_{slice_begin}, slice_begin_{slice_begin},
slice_end_{slice_end} slice_end_{slice_end}
...@@ -1724,7 +1728,7 @@ struct DynamicSlice ...@@ -1724,7 +1728,7 @@ struct DynamicSlice
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicSlice, "); printf("Slice, ");
printf("up_lengths_"); printf("up_lengths_");
print_multi_index(up_lengths_); print_multi_index(up_lengths_);
printf("slice_begin_ %d", index_t{slice_begin_}); printf("slice_begin_ %d", index_t{slice_begin_});
......
#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP #ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP #define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform.hpp" #include "multi_index_transform.hpp"
namespace ck { namespace ck {
template <typename LowLength> template <typename LowLength>
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length) __host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
{ {
return DynamicPassThrough<LowLength>{low_length}; return PassThrough<LowLength>{low_length};
} }
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false> template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
...@@ -19,26 +19,25 @@ make_pad_transform(const LowLength& low_length, ...@@ -19,26 +19,25 @@ make_pad_transform(const LowLength& low_length,
const RightPad& right_pad, const RightPad& right_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{}) integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{ {
return DynamicPad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{ return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
low_length, left_pad, right_pad};
} }
template <typename LowLength, typename LeftPad, bool SkipIsValidCheck = false> template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto make_left_pad_transform( __host__ __device__ constexpr auto make_left_pad_transform(
const LowLength& low_length, const LowLength& low_length,
const LeftPad& left_pad, const LeftPadLength& left_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{}) integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{ {
return DynamicLeftPad<LowLength, LeftPad, SkipIsValidCheck>{low_length, left_pad}; return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
} }
template <typename LowLength, typename RightPad, bool SkipIsValidCheck> template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
__host__ __device__ constexpr auto make_right_pad_transform( __host__ __device__ constexpr auto make_right_pad_transform(
const LowLength& low_length, const LowLength& low_length,
const RightPad& right_pad, const RightPadLength& right_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{}) integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{ {
return DynamicRightPad<LowLength, RightPad, SkipIsValidCheck>{low_length, right_pad}; return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
} }
template <typename UpLengths, template <typename UpLengths,
...@@ -47,19 +46,19 @@ template <typename UpLengths, ...@@ -47,19 +46,19 @@ template <typename UpLengths,
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths, __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
const Coefficients& coefficients) const Coefficients& coefficients)
{ {
return DynamicEmbed<UpLengths, Coefficients>{up_lengths, coefficients}; return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
} }
template <typename LowLengths> template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{ {
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION #if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return DynamicMerge_v1_carry_check<LowLengths>{low_lengths}; return Merge_v1_carry_check<LowLengths>{low_lengths};
#else #else
#if 1 #if 1
return DynamicMerge_v2_magic_division<LowLengths>{low_lengths}; return Merge_v2_magic_division<LowLengths>{low_lengths};
#else #else
return DynamicMerge_v2r2_magic_division<LowLengths>{low_lengths}; return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif #endif
#endif #endif
} }
...@@ -68,7 +67,7 @@ template <typename LowLengths> ...@@ -68,7 +67,7 @@ template <typename LowLengths>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths) make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{ {
return DynamicMerge_v2_magic_division<LowLengths>{low_lengths}; return Merge_v2_magic_division<LowLengths>{low_lengths};
} }
template <typename UpLengths, bool Use24BitIntegerCalculation = false> template <typename UpLengths, bool Use24BitIntegerCalculation = false>
...@@ -76,13 +75,13 @@ __host__ __device__ constexpr auto make_unmerge_transform( ...@@ -76,13 +75,13 @@ __host__ __device__ constexpr auto make_unmerge_transform(
const UpLengths& up_lengths, const UpLengths& up_lengths,
integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{}) integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
{ {
return DynamicUnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths}; return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
} }
template <typename LowerIndex> template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx) __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{ {
return DynamicFreeze<LowerIndex>{low_idx}; return Freeze<LowerIndex>{low_idx};
} }
template <typename LowLength, typename SliceBegin, typename SliceEnd> template <typename LowLength, typename SliceBegin, typename SliceEnd>
...@@ -90,14 +89,14 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len ...@@ -90,14 +89,14 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len
const SliceBegin& slice_begin, const SliceBegin& slice_begin,
const SliceEnd& slice_end) const SliceEnd& slice_end)
{ {
return DynamicSlice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end}; return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
} }
template <typename VectorSize, typename UpLength> template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size, __host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length) const UpLength& up_length)
{ {
return DynamicVectorize<VectorSize, UpLength>{vector_size, up_length}; return Vectorize<VectorSize, UpLength>{vector_size, up_length};
} }
} // namespace ck } // namespace ck
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_TENSOR_ADAPTOR_HPP #define CK_TENSOR_ADAPTOR_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
......
#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP #ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP #define CK_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_multi_index_transform.hpp" #include "multi_index_transform.hpp"
namespace ck { namespace ck {
template <index_t NDimHidden, typename VisibleDimensionIds> template <index_t NDimHidden, typename VisibleDimensionIds>
struct DynamicTensorCoordinate; struct TensorCoordinate;
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack> template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct DynamicTensorCoordinateIterator; struct TensorCoordinateIterator;
// Transforms: Tuple<transforms...> // Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...> // LowerDimensionIdss : Tuple<Sequence<...>, ...>
...@@ -21,7 +21,7 @@ template <typename Transforms, ...@@ -21,7 +21,7 @@ template <typename Transforms,
typename UpperDimensionIdss, typename UpperDimensionIdss,
typename VisibleDimensionIds, typename VisibleDimensionIds,
typename ElementSpaceSize> typename ElementSpaceSize>
struct DynamicTensorDescriptor struct TensorDescriptor
{ {
// TODO make these private // TODO make these private
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
...@@ -105,16 +105,16 @@ struct DynamicTensorDescriptor ...@@ -105,16 +105,16 @@ struct DynamicTensorDescriptor
using VisibleIndex = MultiIndex<ndim_visible_>; using VisibleIndex = MultiIndex<ndim_visible_>;
using HiddenIndex = MultiIndex<ndim_hidden_>; using HiddenIndex = MultiIndex<ndim_hidden_>;
using Coordinate = DynamicTensorCoordinate<ndim_hidden_, VisibleDimensionIds>; using Coordinate = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
// may be index_t or Number<> // may be index_t or Number<>
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>; using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
public: public:
__host__ __device__ constexpr DynamicTensorDescriptor() = default; __host__ __device__ constexpr TensorDescriptor() = default;
__host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms, __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
ElementSpaceSize element_space_size) ElementSpaceSize element_space_size)
: transforms_{transforms}, : transforms_{transforms},
element_size_{InitializeElementSize(transforms)}, element_size_{InitializeElementSize(transforms)},
element_space_size_{element_space_size} element_space_size_{element_space_size}
...@@ -159,7 +159,7 @@ struct DynamicTensorDescriptor ...@@ -159,7 +159,7 @@ struct DynamicTensorDescriptor
{ {
static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension"); static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");
return make_dynamic_tensor_coordinate(*this, idx).GetOffset(); return make_tensor_coordinate(*this, idx).GetOffset();
} }
// TODO make these private // TODO make these private
...@@ -196,7 +196,7 @@ struct DynamicTensorDescriptor ...@@ -196,7 +196,7 @@ struct DynamicTensorDescriptor
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicTensorDescriptor, "); printf("TensorDescriptor, ");
static_for<0, ntransform_, 1>{}([&](auto i) { static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: "); printf("transforms: ");
transforms_[i].Print(); transforms_[i].Print();
...@@ -217,7 +217,7 @@ struct DynamicTensorDescriptor ...@@ -217,7 +217,7 @@ struct DynamicTensorDescriptor
}; };
template <index_t NDimHidden, typename VisibleDimensionIds> template <index_t NDimHidden, typename VisibleDimensionIds>
struct DynamicTensorCoordinate struct TensorCoordinate
{ {
// TODO make these private // TODO make these private
static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size(); static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
...@@ -226,9 +226,9 @@ struct DynamicTensorCoordinate ...@@ -226,9 +226,9 @@ struct DynamicTensorCoordinate
using VisibleIndex = MultiIndex<ndim_visible_>; using VisibleIndex = MultiIndex<ndim_visible_>;
public: public:
__host__ __device__ constexpr DynamicTensorCoordinate() = default; __host__ __device__ constexpr TensorCoordinate() = default;
__host__ __device__ constexpr DynamicTensorCoordinate(const HiddenIndex& idx_hidden) __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
: idx_hidden_{idx_hidden} : idx_hidden_{idx_hidden}
{ {
} }
...@@ -252,16 +252,17 @@ struct DynamicTensorCoordinate ...@@ -252,16 +252,17 @@ struct DynamicTensorCoordinate
}; };
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack> template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct DynamicTensorCoordinateIterator struct TensorCoordinateIterator
{ {
// TODO make these private // TODO make these private
using VisibleIndex = MultiIndex<NDimVisible>; using VisibleIndex = MultiIndex<NDimVisible>;
public: public:
__host__ __device__ constexpr DynamicTensorCoordinateIterator() = default; __host__ __device__ constexpr TensorCoordinateIterator() = default;
__host__ __device__ constexpr DynamicTensorCoordinateIterator( __host__
const VisibleIndex& idx_diff_visible, const MultiIndex<NTransform>& do_transforms) __device__ constexpr TensorCoordinateIterator(const VisibleIndex& idx_diff_visible,
const MultiIndex<NTransform>& do_transforms)
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
{ {
} }
...@@ -283,7 +284,7 @@ struct DynamicTensorCoordinateIterator ...@@ -283,7 +284,7 @@ struct DynamicTensorCoordinateIterator
// TODO: How to fix this? It uses an struct instead of lambda because lambda // TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor, and to put it outside the scope where it is used // doesn't have constructor, and to put it outside the scope where it is used
// (transform_dynamic_tensor_descriptor) because template cannot be defined inside a function // (transform_tensor_descriptor) because template cannot be defined inside a function
// template // template
template <typename NewTransforms> template <typename NewTransforms>
struct lambda_get_up_dim_num struct lambda_get_up_dim_num
...@@ -301,10 +302,10 @@ template <typename OldTensorDescriptor, ...@@ -301,10 +302,10 @@ template <typename OldTensorDescriptor,
typename NewLowerDimensionOldVisibleIdss, typename NewLowerDimensionOldVisibleIdss,
typename NewUpperDimensionNewVisibleIdss> typename NewUpperDimensionNewVisibleIdss>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const NewTransforms& new_transforms, const NewTransforms& new_transforms,
NewLowerDimensionOldVisibleIdss, NewLowerDimensionOldVisibleIdss,
NewUpperDimensionNewVisibleIdss) NewUpperDimensionNewVisibleIdss)
{ {
// sanity check // sanity check
{ {
...@@ -376,17 +377,17 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, ...@@ -376,17 +377,17 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
return DynamicTensorDescriptor<remove_cv_t<decltype(all_transforms)>, return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>, remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>, remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(new_visible_dim_hidden_ids)>, remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{all_transforms, remove_cv_t<decltype(element_space_size)>>{all_transforms,
element_space_size}; element_space_size};
} }
template <typename TensorDesc, typename VisibleIndex> template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDesc& tensor_desc, __host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
const VisibleIndex& idx_visible) const VisibleIndex& idx_visible)
{ {
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent"); "wrong! # of dimension inconsistent");
...@@ -416,13 +417,13 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe ...@@ -416,13 +417,13 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
set_container_subset(idx_hidden, dims_low, idx_low); set_container_subset(idx_hidden, dims_low, idx_low);
}); });
return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden}; return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
} }
// UpdateLowerIndexHack: Sequence<...> // UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex // HACK: control UpdateLowerIndex
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack> template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator( __host__ __device__ constexpr auto make_tensor_coordinate_iterator(
const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack) const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack)
{ {
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
...@@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator( ...@@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
}); });
return DynamicTensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{ return TensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{
idx_diff_visible, do_transforms}; idx_diff_visible, do_transforms};
} }
template <typename TensorDesc, typename VisibleIndex> template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible) make_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible)
{ {
constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
return make_dynamic_tensor_coordinate_iterator( return make_tensor_coordinate_iterator(
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{}); TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
} }
template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator> template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator>
__host__ __device__ constexpr void move_dynamic_tensor_coordinate( __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator) TensorCoord& coord,
const TensorCoordIterator& coord_iterator)
{ {
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
...@@ -524,7 +526,7 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate( ...@@ -524,7 +526,7 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
MultiIndex<dims_low.Size()> idx_diff_low; MultiIndex<dims_low.Size()> idx_diff_low;
// HACK: control UpdateLowerIndex for DynamicMerge using hack // HACK: control UpdateLowerIndex for Merge using hack
constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran); constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran);
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{}); tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
...@@ -585,11 +587,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& ...@@ -585,11 +587,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
} }
template <typename TensorDesc> template <typename TensorDesc>
using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate( using TensorCoordinate_t = decltype(make_tensor_coordinate(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
template <typename TensorDesc> template <typename TensorDesc>
using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator( using TensorCoordinateIterator_t = decltype(make_tensor_coordinate_iterator(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
} // namespace ck } // namespace ck
......
#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP #ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP #define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
namespace ck { namespace ck {
...@@ -38,9 +38,8 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt ...@@ -38,9 +38,8 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
template <typename... Lengths, template <typename... Lengths,
typename... Strides, typename... Strides,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false> typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths, const Tuple<Strides...>& strides)
const Tuple<Strides...>& strides)
{ {
constexpr index_t N = sizeof...(Lengths); constexpr index_t N = sizeof...(Lengths);
...@@ -75,12 +74,12 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths, ...@@ -75,12 +74,12 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
#endif #endif
return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>, return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>, remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>, remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>, remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms, remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size}; element_space_size};
} }
// Lengths... can be: // Lengths... can be:
...@@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths, ...@@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
// 2) Number<>, which is known at compile-time // 2) Number<>, which is known at compile-time
template <typename... Lengths> template <typename... Lengths>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths) make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
{ {
constexpr index_t N = sizeof...(Lengths); constexpr index_t N = sizeof...(Lengths);
...@@ -103,17 +102,17 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths) ...@@ -103,17 +102,17 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{}); const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>, return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>, remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>, remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>, remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms, remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size}; element_space_size};
} }
template <typename... Lengths, typename Align> template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align) make_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
{ {
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths ...@@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
}, },
Number<N>{}); Number<N>{});
return make_dynamic_naive_tensor_descriptor_v2(lengths, strides); return make_naive_tensor_descriptor_v2(lengths, strides);
} }
} // namespace ck } // namespace ck
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_adaptor.hpp" #include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp" #include "threadwise_contraction_dlops.hpp"
namespace ck { namespace ck {
...@@ -73,7 +73,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 ...@@ -73,7 +73,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */) MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
{ {
const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor( const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
AKMBlockDesc{}, AKMBlockDesc{},
make_tuple(make_pass_through_transform(Number<K>{}), make_tuple(make_pass_through_transform(Number<K>{}),
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))), make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
...@@ -86,7 +86,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 ...@@ -86,7 +86,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */) MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
{ {
const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
BKNBlockDesc{}, BKNBlockDesc{},
make_tuple(make_pass_through_transform(Number<K>{}), make_tuple(make_pass_through_transform(Number<K>{}),
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))), make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
...@@ -357,34 +357,32 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 ...@@ -357,34 +357,32 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
private: private:
// A[K, M0, M1] // A[K, M0, M1]
static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{})); make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));
// B[K, N0, N1] // B[K, N0, N1]
static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{})); make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));
using AThreadCopy = using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA, FloatA,
FloatA, decltype(a_k_m0_m1_block_desc_),
decltype(a_k_m0_m1_block_desc_), decltype(a_k_m0_m1_thread_desc_),
decltype(a_k_m0_m1_thread_desc_), Sequence<KPerThread, 1, M1PerThreadM11>,
Sequence<KPerThread, 1, M1PerThreadM11>, Sequence<0, 1, 2>,
Sequence<0, 1, 2>, 2,
2, AThreadCopyScalarPerVector_M11,
AThreadCopyScalarPerVector_M11, 1>;
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
using BThreadCopy = FloatB,
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB, decltype(b_k_n0_n1_block_desc_),
FloatB, decltype(b_k_n0_n1_thread_desc_),
decltype(b_k_n0_n1_block_desc_), Sequence<KPerThread, 1, N1PerThreadN11>,
decltype(b_k_n0_n1_thread_desc_), Sequence<0, 1, 2>,
Sequence<KPerThread, 1, N1PerThreadN11>, 2,
Sequence<0, 1, 2>, BThreadCopyScalarPerVector_N11,
2, 1>;
BThreadCopyScalarPerVector_N11,
1>;
CIndex c_thread_origin_data_idx_; CIndex c_thread_origin_data_idx_;
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_adaptor.hpp" #include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction_dlops.hpp" #include "threadwise_contraction_dlops.hpp"
namespace ck { namespace ck {
...@@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B ...@@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
{ {
const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor( const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
a_block_desc_bk0_bm_bk1, a_block_desc_bk0_bm_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}), make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})), make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
...@@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B ...@@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
{ {
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor( const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
b_block_desc_bk0_bn_bk1, b_block_desc_bk0_bn_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}), make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})), make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
...@@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B ...@@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
private: private:
// A[BK0, BM0, BM1, BK1] // A[BK0, BM0, BM1, BK1]
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_naive_tensor_descriptor_packed(make_tuple(
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{})); Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
// B[BK0, BN0, BN1, BK1] // B[BK0, BN0, BN1, BK1]
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_naive_tensor_descriptor_packed(make_tuple(
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{})); Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatA, FloatA,
FloatA, FloatA,
decltype(a_block_desc_bk0_bm0_bm1_bk1_), decltype(a_block_desc_bk0_bm0_bm1_bk1_),
...@@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B ...@@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatB, FloatB,
FloatB, FloatB,
decltype(b_block_desc_bk0_bn0_bn1_bk1_), decltype(b_block_desc_bk0_bn0_bn1_bk1_),
......
...@@ -31,25 +31,24 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -31,25 +31,24 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
// HACK: fix this @Jing Zhang // HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4; static constexpr index_t KPerThreadSubC = 4;
static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{})); make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
using AThreadCopy = using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA, FloatA,
FloatA, BlockMatrixA,
BlockMatrixA, decltype(a_thread_mtx_),
decltype(a_thread_mtx_), Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<EPerThreadLoop, KPerThreadSubC>, Sequence<0, 1>,
Sequence<0, 1>, 1,
1, ThreadGemmADataPerRead_K,
ThreadGemmADataPerRead_K, 1>;
1>;
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP #define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp" #include "xdlops_gemm.hpp"
namespace ck { namespace ck {
...@@ -191,35 +191,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -191,35 +191,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
private: private:
// A[K, M] // A[K, M]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto a_thread_desc_ =
make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
// B[K, N] // B[K, N]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto b_thread_desc_ =
make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto c_thread_desc_ =
make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
ABlockDesc, ABlockDesc,
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, MRepeat, 1, K1>, Sequence<1, MRepeat, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, 3,
K1, K1,
1>; 1>;
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
BBlockDesc, BBlockDesc,
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, NRepeat, 1, K1>, Sequence<1, NRepeat, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, 3,
K1, K1,
1>; 1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_; BThreadCopy b_thread_copy_;
...@@ -486,35 +486,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -486,35 +486,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
private: private:
// A[K, M] // A[K, M]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto a_thread_desc_ =
make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
// B[K, N] // B[K, N]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto b_thread_desc_ =
make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto c_thread_desc_ =
make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
ABlockDesc, ABlockDesc,
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, K1>, Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, 3,
1, // K1, 1, // K1,
1>; 1>;
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
BBlockDesc, BBlockDesc,
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, K1>, Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, 3,
1, // K1, 1, // K1,
1>; 1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_; BThreadCopy b_thread_copy_;
......
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP #ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP #define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
namespace ck { namespace ck {
// this version does following things to avoid scratch memory issue // this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer // 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
...@@ -33,16 +33,16 @@ template <index_t BlockSize, ...@@ -33,16 +33,16 @@ template <index_t BlockSize,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun> bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4 struct BlockwiseTensorSliceTransfer_v4
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc, __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
const Index& src_block_slice_origin, const Index& src_block_slice_origin,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_block_slice_origin) const Index& dst_block_slice_origin)
: threadwise_transfer_( : threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>()) src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
...@@ -147,22 +147,22 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 ...@@ -147,22 +147,22 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths, ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
DstInMemOp, DstInMemOp,
SrcData, SrcData,
DstData, DstData,
SrcDesc, SrcDesc,
DstDesc, DstDesc,
SrcDimAccessOrder, SrcDimAccessOrder,
DstDimAccessOrder, DstDimAccessOrder,
SrcVectorDim, SrcVectorDim,
DstVectorDim, DstVectorDim,
SrcScalarPerVector, SrcScalarPerVector,
DstScalarPerVector, DstScalarPerVector,
SrcScalarStrideInVector, SrcScalarStrideInVector,
DstScalarStrideInVector, DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>; ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_; ThreadwiseTransfer threadwise_transfer_;
}; };
......
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP #ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP #define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer_v2.hpp"
namespace ck { namespace ck {
// this version does following things to avoid scratch memory issue // this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer // 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
...@@ -31,17 +31,16 @@ template <index_t BlockSize, ...@@ -31,17 +31,16 @@ template <index_t BlockSize,
typename DstVectorTensorContiguousDimOrder, typename DstVectorTensorContiguousDimOrder,
bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun> bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4r1 struct BlockwiseTensorSliceTransfer_v4r1
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1( __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
const SrcDesc& src_desc, const Index& src_block_slice_origin,
const Index& src_block_slice_origin, const DstDesc& dst_desc,
const DstDesc& dst_desc, const Index& dst_block_slice_origin)
const Index& dst_block_slice_origin)
: threadwise_transfer_( : threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>()) src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
...@@ -136,20 +135,20 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1 ...@@ -136,20 +135,20 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3r1<ThreadSliceLengths, ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
DstInMemOp, DstInMemOp,
SrcData, SrcData,
DstData, DstData,
SrcDesc, SrcDesc,
DstDesc, DstDesc,
SrcDimAccessOrder, SrcDimAccessOrder,
DstDimAccessOrder, DstDimAccessOrder,
SrcVectorTensorLengths, SrcVectorTensorLengths,
DstVectorTensorLengths, DstVectorTensorLengths,
SrcVectorTensorContiguousDimOrder, SrcVectorTensorContiguousDimOrder,
DstVectorTensorContiguousDimOrder, DstVectorTensorContiguousDimOrder,
ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>; ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_; ThreadwiseTransfer threadwise_transfer_;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment