"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "e15769f9f2d91b9ce2cc4f4448c7112676e63afa"
Commit 7b002f23 authored by Jing Zhang's avatar Jing Zhang
Browse files

add v4r4 xdlops

parent 87a75734
@@ -27,9 +27,6 @@
#include "amd_inline_asm.hpp"
#endif
#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#include "amd_xdlops_inline_asm.hpp"
#endif
#endif
#include "common_header.hpp"
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "float_types.h"
template <class T,
class InDesc,
@@ -10,44 +12,47 @@ template <class T,
class ConvDilations,
class InLeftPads,
class InRightPads>
void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
// read params: problem description
constexpr index_t G = CK_PARAM_PROBLEM_G;
constexpr index_t N = CK_PARAM_PROBLEM_N;
constexpr index_t K = CK_PARAM_PROBLEM_K;
constexpr index_t C = CK_PARAM_PROBLEM_C;
constexpr index_t Hi = CK_PARAM_PROBLEM_HI;
constexpr index_t Wi = CK_PARAM_PROBLEM_WI;
constexpr index_t Ho = CK_PARAM_PROBLEM_HO;
constexpr index_t Wo = CK_PARAM_PROBLEM_WO;
constexpr index_t Y = CK_PARAM_PROBLEM_Y;
constexpr index_t X = CK_PARAM_PROBLEM_X;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc =
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
constexpr auto wei_kcyx_desc =
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
constexpr auto out_nkhw_desc =
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H;
constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W;
// read params: problem description
constexpr index_t G = 1;
constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H;
constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W;
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H;
constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W;
constexpr index_t C = in_nchw_desc.GetLength(I1);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H;
constexpr index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W;
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
constexpr auto CPerGroup = C / G;
@@ -58,31 +63,27 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr auto out_n_k_ho_wo_desc =
make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
using ConvStrides = Sequence<ConvStrideH, ConvStrideW>;
using ConvDilations = Sequence<ConvDilationH, ConvDilationW>;
using InLeftPads = Sequence<InLeftPadH, InLeftPadW>;
using InRightPads = Sequence<InRightPadH, InRightPadW>;
// read params: tuning parameters
constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK;
constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK;
constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK;
constexpr index_t GemmMPerWave = CK_PARAM_TUNABLE_GEMM_M_PER_WAVE;
constexpr index_t GemmNPerWave = CK_PARAM_TUNABLE_GEMM_N_PER_WAVE;
constexpr index_t GemmKPack = CK_PARAM_TUNABLE_GEMM_KPACK;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 1;
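// The tunable tile sizes above are now fixed at compile time rather than read from the
// CK_PARAM_TUNABLE_* macros: a 128x128 block tile built from 64x64 wave tiles,
// with GemmKPerBlock = 8 and GemmKPack = 1.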
// read params: dependent parameters
constexpr index_t BlockSize = CK_PARAM_DEPENDENT_BLOCK_SIZE;
constexpr index_t GridSize = CK_PARAM_DEPENDENT_GRID_SIZE;
constexpr index_t BlockSize = 256;
constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
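// Implicit GEMM view of the forward convolution: GemmM spans the K output channels and
// GemmN spans the N * Ho * Wo output pixels, so GridSize launches one workgroup per
// GemmMPerBlock x GemmNPerBlock tile of the output.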
// A matrix copy
constexpr index_t GemmABlockCopyClusterLengths_GemmK =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
constexpr index_t GemmABlockCopyClusterLengths_GemmM =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;
constexpr index_t GemmABlockCopyClusterLengths_GemmK = 4;
constexpr index_t GemmABlockCopyClusterLengths_GemmM = 64;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack = 1;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK;
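// With the values above this is 8 / 4 = 2 elements per thread along GemmK.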
@@ -107,19 +108,13 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
using GemmABlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
using GemmABlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmM, GemmKPack]
constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_KPACK;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack = 1;
// B matrix Copy
constexpr index_t GemmBBlockCopyClusterLengths_GemmK =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;
constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 4;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 64;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 1;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
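// Likewise for the B block: 8 / 4 = 2 elements per thread along GemmK.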
@@ -144,50 +139,90 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
using GemmBBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
using GemmBBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack]
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack = 1;
// gridwise GEMM
constexpr auto wkgrp_schd_order = NBlock1MBlock0;
constexpr auto gridwise_conv =
GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
GridSize,
BlockSize,
FLOAT, // Input data type
FLOAT_ACCUM, // Acc data type
FLOAT, // Output data type
decltype(in_n_c_hi_wi_desc),
decltype(wei_k_cpergroup_y_x_desc),
decltype(out_n_k_ho_wo_desc),
G,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPack,
GemmABlockCopySubLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyThreadClusterArrangeOrder,
GemmABlockCopySrcAccessOrder,
GemmABlockCopyDstAccessOrder,
GemmABlockCopySrcDataPerRead_GemmKPack,
GemmABlockCopyDstDataPerWrite_GemmKPack,
GemmBBlockCopySubLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyThreadClusterArrangeOrder,
GemmBBlockCopySrcAccessOrder,
GemmBBlockCopyDstAccessOrder,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPack,
wkgrp_schd_order>{};
gridwise_conv.Run(p_in_global, p_wei_global, p_out_global);
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
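// TDevice above maps the host element type to the device-side type: half_float::half on
// the host becomes half_t in the kernel, any other T is passed through unchanged.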
using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
GridSize,
BlockSize,
TDevice, // Input data type
TDevice, // Acc data type
TDevice, // Output data type
decltype(in_n_c_hi_wi_desc),
decltype(wei_k_cpergroup_y_x_desc),
decltype(out_n_k_ho_wo_desc),
G,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPack,
GemmABlockCopySubLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyThreadClusterArrangeOrder,
GemmABlockCopySrcAccessOrder,
GemmABlockCopyDstAccessOrder,
GemmABlockCopySrcDataPerRead_GemmKPack,
GemmABlockCopyDstDataPerWrite_GemmKPack,
GemmBBlockCopySubLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyThreadClusterArrangeOrder,
GemmBBlockCopySrcAccessOrder,
GemmBBlockCopyDstAccessOrder,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPack,
wkgrp_schd_order>;
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
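// Each of the 5 outer iterations below times nrepeat kernel launches, reporting the
// average time and achieved TFlop/s; the output is copied back to the host only after all runs.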
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
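// One launch per repetition: GridSize workgroups of BlockSize threads; the two zeros
// are (presumably) the dynamic LDS byte count and the stream.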
launch_kernel(run_gridwise_operation<gridwise_conv,
const TDevice* const __restrict__,
const TDevice* const __restrict__,
TDevice* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
@@ -13,6 +13,7 @@
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dummy_static_transform.hpp"
#include "device_dummy_dynamic_transform_v1.hpp"
@@ -111,17 +112,17 @@ int main(int argc, char* argv[])
RightPads{},
nrepeat);
#elif 1
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
......