Commit 0c8aa120 authored by ltqin's avatar ltqin
Browse files

V4R4R4XDLNHWC fp16

parent 9da60a34
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: in
// B: wei
// C: out
// GemmM = N * Ho * Wo
// GemmN = K
// GemmK = Y * X * C
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value,
typename GemmKBatchType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>,
GemmKBatchType GemmKBatch)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = Y * X * C;
const auto GemmN = K;
const auto GemmKTotal = N * Ho * Wo;
const auto GemmK = GemmKTotal / GemmKBatch;
const auto GemmK0 = GemmK / GemmK1;
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmktotal_gemmm_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: output tensor
const auto out_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
...@@ -20,8 +20,7 @@ template <typename... In, ...@@ -20,8 +20,7 @@ template <typename... In,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads, typename InRightPads,
index_t GemmK1Value, index_t GemmK1Value>
typename GemmKBatchType>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc, const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
...@@ -31,8 +30,7 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( ...@@ -31,8 +30,7 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
Number<GemmK1Value>, Number<GemmK1Value>)
GemmKBatchType GemmKBatch)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -66,11 +64,10 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( ...@@ -66,11 +64,10 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const auto InRightPadH = in_right_pads[I0]; const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1]; const auto InRightPadW = in_right_pads[I1];
const auto GemmM = Y * X * C; const auto GemmM = Y * X * C;
const auto GemmN = K; const auto GemmN = K;
const auto GemmKTotal = N * Ho * Wo; const auto GemmK = N * Ho * Wo;
const auto GemmK = GemmKTotal / GemmKBatch; const auto GemmK0 = GemmK / GemmK1;
const auto GemmK0 = GemmK / GemmK1;
// A: input tensor // A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
...@@ -91,30 +88,33 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( ...@@ -91,30 +88,33 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmktotal_gemmm_grid_desc = const auto in_gemmk_gemmm_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmk0_gemmm_gemmk1_grid_desc =
in_gemmktotal_gemmm_grid_desc, transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)), make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: output tensor // B: output tensor
const auto out_gemmktotal_gemmn_grid_desc = const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(out_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: weight tensor // C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
...@@ -123,8 +123,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( ...@@ -123,8 +123,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, out_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmm_gemmn_grid_desc);
} }
......
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" #include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp" #include "driver_gemm_xdlops_v2r4.hpp"
template <typename TIn, template <typename TIn,
...@@ -109,7 +109,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_ ...@@ -109,7 +109,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif #endif
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
in_n_hi_wi_c_desc, in_n_hi_wi_c_desc,
wei_k_y_x_c_desc, wei_k_y_x_c_desc,
out_n_ho_wo_k_desc, out_n_ho_wo_k_desc,
......
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" #include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp" #include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei, template <typename TIn,
typename TWei,
typename TAcc, typename TAcc,
typename TOut, typename TOut,
typename InLengths, typename InLengths,
...@@ -22,8 +23,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -22,8 +23,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c, const Tensor<TIn>& in_n_hi_wi_c,
Tensor<TInWei>& wei_k_y_x_c, Tensor<TWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k, const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
...@@ -36,8 +37,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -36,8 +37,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
...@@ -194,9 +195,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -194,9 +195,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
{ {
float ave_time = driver_gemm_xdlops_v2r3< float ave_time = driver_gemm_xdlops_v2r3<
BlockSize, BlockSize,
TInWei, TIn,
TAcc, TAcc,
TOut, TWei,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
...@@ -234,9 +235,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -234,9 +235,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), >(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc, out_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc, wei_gemmm_gemmn_grid_desc,
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1 #define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0 #define USE_CONV_WRW_V4R4R4_XDL_NHWC 1
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0 #define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 1 #define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 1
#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1 #define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1
...@@ -306,6 +306,7 @@ int main(int argc, char* argv[]) ...@@ -306,6 +306,7 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nhwc(); const auto tmp = f_make_for_device_nhwc();
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t, device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
wei_data_t,
acc_data_t, acc_data_t,
out_data_t>( out_data_t>(
tmp[I0], tmp[I0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment