Commit 3c150a8f authored by ltqin's avatar ltqin
Browse files

v4r4r5 fp16

parent 13323ce4
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp" #include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp" #include "driver_gemm_xdlops_v2r4.hpp"
template <typename TInWei, template <typename TIn,
typename TWei,
typename TAcc, typename TAcc,
typename TOut, typename TOut,
typename InLengths, typename InLengths,
...@@ -23,8 +24,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ ...@@ -23,8 +24,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c, const Tensor<TIn>& in_n_hi_wi_c,
Tensor<TInWei>& wei_k_y_x_c, Tensor<TWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k, const Tensor<TOut>& out_n_ho_wo_k,
GemmKBatchType GemmKBatch, GemmKBatchType GemmKBatch,
ck::index_t nrepeat) ck::index_t nrepeat)
...@@ -38,8 +39,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ ...@@ -38,8 +39,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
...@@ -204,9 +205,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ ...@@ -204,9 +205,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
{ {
float ave_time = driver_gemm_xdlops_v2r4< float ave_time = driver_gemm_xdlops_v2r4<
BlockSize, BlockSize,
TInWei, TIn,
TAcc, TAcc,
TOut, TWei,
InMemoryDataOperationEnum_t::AtomicAdd, InMemoryDataOperationEnum_t::AtomicAdd,
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
...@@ -245,8 +246,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ ...@@ -245,8 +246,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat false // CAccessOrderMRepeatNRepeat
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), >(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc, wei_gemmm_gemmn_grid_desc,
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0 #define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0 #define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0 #define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 1 #define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0
#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1 #define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1
enum ConvBackwardWeightAlgo enum ConvBackwardWeightAlgo
...@@ -125,18 +125,21 @@ int main(int argc, char* argv[]) ...@@ -125,18 +125,21 @@ int main(int argc, char* argv[])
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
#endif #endif
#if 1 #if 0
using in_data_t = float; using in_data_t = float;
using wei_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
#elif 1 #elif 1
using in_data_t = half_t; using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t; using out_data_t = half_t;
using acc_data_t = float;
using wei_data_t = float;
#elif 1 #elif 1
using in_data_t = int8_t; using in_data_t = int8_t;
using acc_data_t = int32_t;
using out_data_t = int8_t; using out_data_t = int8_t;
using acc_data_t = int32_t;
using wei_data_t = int8_t;
#endif #endif
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
...@@ -177,8 +180,8 @@ int main(int argc, char* argv[]) ...@@ -177,8 +180,8 @@ int main(int argc, char* argv[])
} }
Tensor<in_data_t> in(in_lengths_host); Tensor<in_data_t> in(in_lengths_host);
Tensor<in_data_t> wei_device(wei_lengths_host); Tensor<wei_data_t> wei_device(wei_lengths_host);
Tensor<out_data_t> wei_host(wei_lengths_host); Tensor<wei_data_t> wei_host(wei_lengths_host);
Tensor<out_data_t> out(out_lengths_host); Tensor<out_data_t> out(out_lengths_host);
std::cout << "layout: " << layout << std::endl; std::cout << "layout: " << layout << std::endl;
...@@ -385,6 +388,7 @@ int main(int argc, char* argv[]) ...@@ -385,6 +388,7 @@ int main(int argc, char* argv[])
device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk< device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk<
in_data_t, in_data_t,
wei_data_t,
acc_data_t, acc_data_t,
out_data_t>(tmp[I0], out_data_t>(tmp[I0],
tmp[I1], tmp[I1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment