Commit 7b002f23 authored by Jing Zhang's avatar Jing Zhang
Browse files

add v4r4 xdlops

parent 87a75734
...@@ -27,9 +27,6 @@ ...@@ -27,9 +27,6 @@
#include "amd_inline_asm.hpp" #include "amd_inline_asm.hpp"
#endif #endif
#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp" #include "amd_xdlops.hpp"
#include "amd_xdlops_inline_asm.hpp"
#endif
#endif #endif
#include "common_header.hpp" #include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" #include "gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "float_types.h"
template <class T, template <class T,
class InDesc, class InDesc,
...@@ -10,8 +12,8 @@ template <class T, ...@@ -10,8 +12,8 @@ template <class T,
class ConvDilations, class ConvDilations,
class InLeftPads, class InLeftPads,
class InRightPads> class InRightPads>
void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
(InDesc, InDesc,
const Tensor<T>& in_nchw, const Tensor<T>& in_nchw,
WeiDesc, WeiDesc,
const Tensor<T>& wei_kcyx, const Tensor<T>& wei_kcyx,
...@@ -25,29 +27,32 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -25,29 +27,32 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
{ {
using namespace ck; using namespace ck;
// read params: problem description constexpr auto I0 = Number<0>{};
constexpr index_t G = CK_PARAM_PROBLEM_G; constexpr auto I1 = Number<1>{};
constexpr index_t N = CK_PARAM_PROBLEM_N; constexpr auto I2 = Number<2>{};
constexpr index_t K = CK_PARAM_PROBLEM_K; constexpr auto I3 = Number<3>{};
constexpr index_t C = CK_PARAM_PROBLEM_C;
constexpr index_t Hi = CK_PARAM_PROBLEM_HI; constexpr auto in_nchw_desc =
constexpr index_t Wi = CK_PARAM_PROBLEM_WI; make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
constexpr index_t Ho = CK_PARAM_PROBLEM_HO; constexpr auto wei_kcyx_desc =
constexpr index_t Wo = CK_PARAM_PROBLEM_WO; make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
constexpr index_t Y = CK_PARAM_PROBLEM_Y; constexpr auto out_nkhw_desc =
constexpr index_t X = CK_PARAM_PROBLEM_X; make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H; // read params: problem description
constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W; constexpr index_t G = 1;
constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H; constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W; constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H; constexpr index_t C = in_nchw_desc.GetLength(I1);
constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W; constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H; constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W; constexpr index_t X = wei_kcyx_desc.GetLength(I3);
constexpr auto CPerGroup = C / G; constexpr auto CPerGroup = C / G;
...@@ -58,31 +63,27 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -58,31 +63,27 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr auto out_n_k_ho_wo_desc = constexpr auto out_n_k_ho_wo_desc =
make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{}); make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
using ConvStrides = Sequence<ConvStrideH, ConvStrideW>;
using ConvDilations = Sequence<ConvDilationH, ConvDilationW>;
using InLeftPads = Sequence<InLeftPadH, InLeftPadW>;
using InRightPads = Sequence<InRightPadH, InRightPadW>;
// read params: tunning parameters // read params: tunning parameters
constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = CK_PARAM_TUNABLE_GEMM_M_PER_WAVE; constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = CK_PARAM_TUNABLE_GEMM_N_PER_WAVE; constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = CK_PARAM_TUNABLE_GEMM_KPACK; constexpr index_t GemmKPack = 1;
// read params: dependent parameters // read params: dependent parameters
constexpr index_t BlockSize = CK_PARAM_DEPENDENT_BLOCK_SIZE; constexpr index_t BlockSize = 256;
constexpr index_t GridSize = CK_PARAM_DEPENDENT_GRID_SIZE;
constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
// A matrix copy // A matrix copy
constexpr index_t GemmABlockCopyClusterLengths_GemmK = constexpr index_t GemmABlockCopyClusterLengths_GemmK = 4;
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K; constexpr index_t GemmABlockCopyClusterLengths_GemmM = 64;
constexpr index_t GemmABlockCopyClusterLengths_GemmM = constexpr index_t GemmABlockCopyClusterLengths_GemmKPack = 1;
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK = constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK; GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK;
...@@ -107,19 +108,13 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -107,19 +108,13 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
using GemmABlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack] using GemmABlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
using GemmABlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmM, GemmKPack] using GemmABlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmM, GemmKPack]
constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack = constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack = 1;
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_KPACK; constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
// B matrix Copy // B matrix Copy
constexpr index_t GemmBBlockCopyClusterLengths_GemmK = constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 4;
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K; constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 64;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN = constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 1;
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK = constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK; GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
...@@ -144,22 +139,20 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -144,22 +139,20 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
using GemmBBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN] using GemmBBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
using GemmBBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack] using GemmBBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack]
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
// gridwise GEMM // gridwise GEMM
constexpr auto wkgrp_schd_order = NBlock1MBlock0; constexpr auto wkgrp_schd_order = NBlock1MBlock0;
constexpr auto gridwise_conv = using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
GridSize, GridSize,
BlockSize, BlockSize,
FLOAT, // Input data type TDevice, // Input data type
FLOAT_ACCUM, // Acc data type TDevice, // Acc data type
FLOAT, // Ouput data type TDevice, // Ouput data type
decltype(in_n_c_hi_wi_desc), decltype(in_n_c_hi_wi_desc),
decltype(wei_k_cpergroup_y_x_desc), decltype(wei_k_cpergroup_y_x_desc),
decltype(out_n_k_ho_wo_desc), decltype(out_n_k_ho_wo_desc),
...@@ -188,6 +181,48 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -188,6 +181,48 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
GemmBBlockCopyDstAccessOrder, GemmBBlockCopyDstAccessOrder,
GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPack, GemmBBlockCopyDstDataPerWrite_GemmKPack,
wkgrp_schd_order>{}; wkgrp_schd_order>;
gridwise_conv.Run(p_in_global, p_wei_global, p_out_global);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
launch_kernel(run_gridwise_operation<gridwise_conv,
const TDevice* const __restrict__,
const TDevice* const __restrict__,
TDevice* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
} }
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dummy_static_transform.hpp" #include "device_dummy_static_transform.hpp"
#include "device_dummy_dynamic_transform_v1.hpp" #include "device_dummy_dynamic_transform_v1.hpp"
...@@ -111,7 +112,7 @@ int main(int argc, char* argv[]) ...@@ -111,7 +112,7 @@ int main(int argc, char* argv[])
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 1
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment