Commit d0b43b08 authored by ltqin's avatar ltqin
Browse files

change device conv variable names

parent 4ec493ec
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk.hpp" #include "transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp" #include "driver_gemm_xdlops_v3r1.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -176,14 +176,14 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk( ...@@ -176,14 +176,14 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4; constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>; using GemmABlockTransferThreadSliceLengths_GemmG_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmG_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 8>; using GemmBBlockTransferThreadSliceLengths_GemmG_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmG_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
...@@ -228,63 +228,69 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk( ...@@ -228,63 +228,69 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
Number<GemmK1>{}); Number<GemmK1>{});
/*
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; const auto in_gemmg_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; const auto wei_gemmg_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2]; const auto out_gemmg_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = constexpr auto in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmG
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmG
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN constexpr auto wei_gemmg_gemmk0_gemmn_gemmk1_grid_step_hacks =
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmG
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmG
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmK0
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 constexpr auto out_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: M1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 7+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 8+: N2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: M1
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M3
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 7-: M4
Sequence<0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 8-: N2
constexpr auto in_gemmg_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
constexpr auto wei_gemmg_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_gemm_xdlops_v2r3< float ave_time = driver_gemm_xdlops_v3r1<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmg_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmg_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmg_gemmm_gemmn_grid_desc),
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
...@@ -293,63 +299,64 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk( ...@@ -293,63 +299,64 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
GemmK1, GemmK1,
MRepeat, MRepeat,
NRepeat, NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, GemmABlockTransferThreadSliceLengths_GemmG_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, GemmABlockTransferThreadClusterLengths_GemmG_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>, Sequence<0, 2, 1, 3>,
Sequence<1, 0, 2>, Sequence<0, 2, 1, 3>,
2, 3,
GemmABlockTransferSrcScalarPerVector_GemmK1, GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1, GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadSliceLengths_GemmG_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadClusterLengths_GemmG_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>, Sequence<0, 2, 1, 3>,
Sequence<1, 0, 2>, Sequence<0, 2, 1, 3>,
2, 3,
GemmBBlockTransferSrcScalarPerVector_GemmK1, GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1, GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, Sequence<0, 3, 4, 1, 2, 8, 6, 5, 7>,
7, 8,
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(wei_gemmg_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(out_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmg_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmg_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), >(static_cast<TInWei*>(in_n_hi_wi_g_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(wei_g_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_ho_wo_g_k_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc, in_gemmg_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmg_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc, out_gemmg_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm_gemmk1_grid_step_hacks, in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks, wei_gemmg_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, out_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, in_gemmg_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, wei_gemmg_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
{ {
const auto N = out_n_ho_wo_k_lengths[I0]; const auto G = wei_g_k_y_x_c_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3]; const auto N = out_n_ho_wo_g_k_lengths[I0];
const auto C = wei_k_y_x_c_lengths[I3]; const auto K = out_n_ho_wo_g_k_lengths[I4];
const auto C = wei_g_k_y_x_c_lengths[I4];
const auto Ho = out_n_ho_wo_k_lengths[I1]; const auto Ho = out_n_ho_wo_g_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2]; const auto Wo = out_n_ho_wo_g_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1]; const auto Y = wei_g_k_y_x_c_lengths[I2];
const auto X = wei_k_y_x_c_lengths[I2]; const auto X = wei_g_k_y_x_c_lengths[I3];
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / float perf = static_cast<float>((std::size_t(2) * G * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time; (std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl; << std::endl;
} }
} }
*/
// copy result back to host // copy result back to host
out_n_ho_wo_g_k_device_buf.FromDevice(out_n_ho_wo_g_k.mData.data()); out_n_ho_wo_g_k_device_buf.FromDevice(out_n_ho_wo_g_k.mData.data());
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment