Commit ca313a29 authored by letaoqin

Merge branch 'develop' into dl_conv_multiple_d

parents d47bf127 8784a72e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using GNHWC = ck::tensor_layout::convolution::GNHWC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GK = ck::tensor_layout::convolution::G_K;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using GK_Tuple = ck::Tuple<GK>;
using GK_GK_Tuple = ck::Tuple<GK, GK>;
using I32_Tuple = ck::Tuple<int32_t>;
using F32_Tuple = ck::Tuple<float>;
using I32_F32_Tuple = ck::Tuple<int32_t, float>;
using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<PassThrough>;
using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Relu>;
using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<PassThrough>;
using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Relu>;
using Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<PassThrough>;
using Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Relu>;
using Add_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<PassThrough>;
using Add_Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Relu>;
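// Annotation (not in the original source): these epilogues implement int8
// requantization. Roughly, Activation_Mul_Clamp computes
//   e = clamp_to_int8(requant_scale * activation(acc32)),
// the Add_* variants first add a bias D tensor to the int32 accumulator, and
// the *_Mul2_* variants read the requantization scale from a per-channel D
// tensor (float) instead of using a single per-layer constant.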
static constexpr ck::index_t NDimSpatial = 2;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
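// Annotation (not in the original source): Filter1x1Pad0 and
// Filter1x1Stride1Pad0 are fast-path specializations -- with a 1x1 filter,
// zero padding (and unit stride) the convolution degenerates to a plain GEMM,
// so the input index calculation can be simplified.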
template <typename DsLayout,
typename DsDatatype,
typename OutElementOp,
ConvolutionForwardSpecialization ConvSpec>
// clang-format off
using device_conv2d_int8_instances =
std::tuple <
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>
>;
// clang-format on
// For conv + multiple 32-bit Ds: the bit width of the Ds tensors affects the ScalarPerVector of C.
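// Annotation: E is int8, so ScalarPerVector = 16 moves 16 x 1 B = 16 B per
// vector access; with 32-bit Ds (int32_t / float) the same vector width would
// move 16 x 4 B = 64 B, hence the instances below halve the CDE
// ScalarPerVector to 8 (compare the last template argument with the
// no-Ds instances above, which use 16).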
template <typename DsLayout,
typename DsDatatype,
typename OutElementOp,
ConvolutionForwardSpecialization ConvSpec>
// clang-format off
using device_conv2d_int8_32Ds_instances =
std::tuple <
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>
>;
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_conv2d_xdl_int8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_conv2d_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul2_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Mul2_Clamp, ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Mul2_Clamp, ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Mul2_Clamp, ConvFwd1x1S1P0>{});
}
void add_device_conv2d_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Relu_Mul2_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Relu_Mul2_Clamp, ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Relu_Mul2_Clamp, ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Relu_Mul2_Clamp, ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
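A minimal usage sketch, not part of this commit: client code fills a vector with the per-channel quantization instances registered above and can then query each one. The template parameters mirror the signature of add_device_conv2d_perchannel_quantization_int8_instances; GetTypeString() is CK's instance-name accessor.
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>
#include "device_conv2d_xdl_int8_instance.hpp"
void list_perchannel_int8_instances()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;
    // the element type must match the add_... function signature above
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                                              GNHWC,
                                                              GKYXC,
                                                              GK_Tuple,
                                                              GNHWK,
                                                              int8_t,
                                                              int8_t,
                                                              F32_Tuple,
                                                              int8_t,
                                                              PassThrough,
                                                              PassThrough,
                                                              Mul2_Clamp>>>
        op_ptrs;
    add_device_conv2d_perchannel_quantization_int8_instances(op_ptrs);
    // each entry is one tuning variant of the same grouped conv kernel
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << std::endl;
}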
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_conv2d_xdl_int8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_conv2d_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Mul_Clamp, ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Mul_Clamp, ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Mul_Clamp, ConvFwd1x1S1P0>{});
}
void add_device_conv2d_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Relu_Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Relu_Mul_Clamp, ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Relu_Mul_Clamp, ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Relu_Mul_Clamp, ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using GNHWC = ck::tensor_layout::convolution::GNHWC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<PassThrough>;
using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Relu>;
static constexpr ck::index_t NDimSpatial = 2;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
// TODO - Add more instances
template <typename OutElementOp, ConvolutionForwardSpecialization ConvSpec>
// clang-format off
using device_conv2d_int8_instances =
std::tuple <
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>
>;
// clang-format on
void add_device_conv2d_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul_Clamp>>>& instances)
{
add_device_operation_instances(instances,
device_conv2d_int8_instances<Mul_Clamp, ConvFwdDefault>{});
add_device_operation_instances(instances,
device_conv2d_int8_instances<Mul_Clamp, ConvFwd1x1P0>{});
add_device_operation_instances(instances,
device_conv2d_int8_instances<Mul_Clamp, ConvFwd1x1S1P0>{});
}
void add_device_conv2d_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Relu_Mul_Clamp>>>& instances)
{
add_device_operation_instances(instances,
device_conv2d_int8_instances<Relu_Mul_Clamp, ConvFwdDefault>{});
add_device_operation_instances(instances,
device_conv2d_int8_instances<Relu_Mul_Clamp, ConvFwd1x1P0>{});
add_device_operation_instances(instances,
device_conv2d_int8_instances<Relu_Mul_Clamp, ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
include_directories(BEFORE
    ${PROJECT_SOURCE_DIR}/
    ${CMAKE_CURRENT_LIST_DIR}/include
)
# ck_profiler
add_subdirectory(src)
set(PROFILER_SOURCE
src/profiler.cpp
src/profile_gemm.cpp
src/profile_gemm_splitk.cpp
src/profile_gemm_bilinear.cpp
src/profile_gemm_bias_add_reduce.cpp
src/profile_gemm_add_add_fastgelu.cpp
src/profile_gemm_reduce.cpp
src/profile_batched_gemm.cpp
src/profile_batched_gemm_gemm.cpp
src/profile_batched_gemm_add_relu_gemm_add.cpp
src/profile_batched_gemm_reduce.cpp
src/profile_grouped_gemm.cpp
src/profile_conv_fwd.cpp
src/profile_conv_fwd_bias_relu.cpp
src/profile_conv_fwd_bias_relu_add.cpp
src/profile_conv_bwd_data.cpp
src/profile_grouped_conv_fwd.cpp
src/profile_grouped_conv_bwd_weight.cpp
src/profile_reduce.cpp
src/profile_groupnorm.cpp
src/profile_layernorm.cpp
src/profile_softmax.cpp
src/profile_batchnorm_fwd.cpp
)
add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE utility)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_splitk_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bilinear_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_conv1d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_conv2d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_conv3d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv1d_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_conv3d_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_conv3d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
target_link_libraries(ckProfiler PRIVATE device_softmax_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_batchnorm_instance)
rocm_install(TARGETS ckProfiler COMPONENT profiler)
@@ -4,7 +4,7 @@
 #pragma once
 #include "ck/utility/data_type.hpp"
-#include "profiler/include/data_type_enum.hpp"
+#include "profiler/data_type_enum.hpp"
 namespace ck {
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <array>
#include <cmath>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <thread>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp"
namespace ck {
namespace profiler {
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
index_t Rank,
index_t NumBatchNormReduceDim>
bool profile_batchnorm_backward_impl(bool do_verification,
int init_method,
bool do_dumpout,
bool time_kernel,
const std::vector<size_t> inOutLengths,
const std::vector<int> reduceDims,
bool haveSavedMeanInvVar,
double epsilon)
{
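    // Background (annotation, not in the original source): with
    //   xhat = (x - mean) * invVar  and  N = number of reduced elements,
    // batchnorm backward computes
    //   dbias  = sum(dy)
    //   dscale = sum(dy * xhat)
    //   dx     = (scale * invVar / N) * (N * dy - dbias - xhat * dscale)
    // The profiler below checks these three outputs against a CPU reference.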
if(inOutLengths.size() != Rank || reduceDims.size() != NumBatchNormReduceDim)
{
throw std::runtime_error("Invalid tensor lengths or number of reduce dimensions!");
};
std::vector<size_t> scaleBiasMeanVarLengths;
// used for calculating the effective number of bytes transferred by each operation
size_t invariant_length = 1;
size_t total_length =
std::accumulate(inOutLengths.begin(), inOutLengths.end(), 1, std::multiplies<size_t>{});
if(std::any_of(reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
throw std::runtime_error("Invalid reduce dimensions!");
for(int dim = 0; dim < Rank; dim++)
{
if(std::none_of(reduceDims.begin(), reduceDims.end(), [&](int d) { return dim == d; }))
{
scaleBiasMeanVarLengths.push_back(inOutLengths[dim]);
invariant_length *= inOutLengths[dim];
};
}
// input data of the batchnorm backward algorithm
Tensor<XDataType> x(inOutLengths);
Tensor<DyDataType> dy(inOutLengths);
Tensor<ScaleDataType> bnScale(scaleBiasMeanVarLengths);
Tensor<MeanVarDataType> savedMean(scaleBiasMeanVarLengths);
Tensor<MeanVarDataType> savedInvVar(scaleBiasMeanVarLengths);
// savedVariance is only used for initializing savedInvVar
Tensor<MeanVarDataType> savedVariance(scaleBiasMeanVarLengths);
// output data of the batchnorm backward algorithm
Tensor<DxDataType> dx_ref(inOutLengths);
Tensor<DxDataType> dx(inOutLengths);
Tensor<DscaleDbiasDataType> dscale(scaleBiasMeanVarLengths);
Tensor<DscaleDbiasDataType> dbias(scaleBiasMeanVarLengths);
Tensor<DscaleDbiasDataType> dscale_ref(scaleBiasMeanVarLengths);
Tensor<DscaleDbiasDataType> dbias_ref(scaleBiasMeanVarLengths);
auto inOutStrides = x.mDesc.GetStrides();
auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides();
std::size_t num_thread = std::thread::hardware_concurrency();
if(haveSavedMeanInvVar)
{
const float x_mean = 0.0f;
const float x_stddev = 1.0f;
const float noise_stddev = 0.0001f;
// input data in normal distribution
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
// initialize savedMean with values that deviate only slightly from the mean of x
savedMean.GenerateTensorValue(GeneratorTensor_4<MeanVarDataType>{x_mean, noise_stddev},
num_thread);
// initialize savedVariance with values that deviate only slightly from the variance of x
savedVariance.GenerateTensorValue(
GeneratorTensor_4<MeanVarDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
auto it_src = savedVariance.mData.begin();
auto it_dst = savedInvVar.mData.begin();
float tmp_epsilon = std::numeric_limits<float>::epsilon();
while(it_src != savedVariance.mData.end())
{
*it_dst = type_convert<MeanVarDataType>(
1.0f / std::sqrt(type_convert<float>(*it_src) + tmp_epsilon));
it_src++;
it_dst++;
};
}
else
{
const float x_mean = 0.0f;
const float x_stddev = 1.0f;
// input data in normal distribution
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
};
if(do_verification)
{
switch(init_method)
{
case 0:
dy.GenerateTensorValue(GeneratorTensor_0<DyDataType>{}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
break;
case 1:
dy.GenerateTensorValue(GeneratorTensor_1<DyDataType>{1}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
break;
case 2:
dy.GenerateTensorValue(GeneratorTensor_2<DyDataType>{-2, 2}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
break;
default:
dy.GenerateTensorValue(GeneratorTensor_3<DyDataType>{-0.2f, 0.2f}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-0.5f, 0.5f}, num_thread);
}
};
// input data of the batchnorm backward algorithm
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem dy_dev(sizeof(DyDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize());
DeviceMem savedMean_dev(sizeof(MeanVarDataType) * savedMean.mDesc.GetElementSpaceSize());
DeviceMem savedInvVar_dev(sizeof(MeanVarDataType) * savedInvVar.mDesc.GetElementSpaceSize());
// output data of the batchnorm backward algorithm
DeviceMem dx_dev(sizeof(DxDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dscale_dev(sizeof(DscaleDbiasDataType) * dscale.mDesc.GetElementSpaceSize());
DeviceMem dbias_dev(sizeof(DscaleDbiasDataType) * dbias.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data());
dy_dev.ToDevice(dy.mData.data());
bnScale_dev.ToDevice(bnScale.mData.data());
if(haveSavedMeanInvVar)
{
savedMean_dev.ToDevice(savedMean.mData.data());
savedInvVar_dev.ToDevice(savedInvVar.mData.data());
};
std::array<index_t, Rank> arrInOutLengths;
std::array<index_t, Rank> arrInOutStrides;
std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarLengths;
std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarStrides;
std::array<int, NumBatchNormReduceDim> arrReduceDims;
std::copy(inOutLengths.begin(), inOutLengths.end(), arrInOutLengths.begin());
std::copy(inOutStrides.begin(), inOutStrides.end(), arrInOutStrides.begin());
std::copy(scaleBiasMeanVarLengths.begin(),
scaleBiasMeanVarLengths.end(),
arrScaleBiasMeanVarLengths.begin());
std::copy(scaleBiasMeanVarStrides.begin(),
scaleBiasMeanVarStrides.end(),
arrScaleBiasMeanVarStrides.begin());
std::copy(reduceDims.begin(), reduceDims.end(), arrReduceDims.begin());
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
// add device batchnorm-backward instances
using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
PassThroughOp,
Rank,
NumBatchNormReduceDim>;
// get device op instances
const auto instance_ptrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << instance_ptrs.size() << " instances" << std::endl;
std::string best_instance_name;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
if(do_verification)
{
using ReferenceBatchNormBwdInstance =
ck::tensor_operation::host::ReferenceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
PassThroughOp,
Rank,
NumBatchNormReduceDim>;
auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};
auto argument_ptr_ref = batchNormBwd_ref.MakeArgumentPointer(
arrInOutLengths,
arrInOutStrides,
arrInOutStrides,
arrInOutStrides,
arrReduceDims,
arrScaleBiasMeanVarLengths,
arrScaleBiasMeanVarStrides,
arrScaleBiasMeanVarStrides,
arrScaleBiasMeanVarStrides,
x.mData.data(),
dy.mData.data(),
bnScale.mData.data(),
haveSavedMeanInvVar ? savedMean.mData.data() : nullptr,
haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr,
epsilon,
PassThroughOp{},
dx_ref.mData.data(),
dscale_ref.mData.data(),
dbias_ref.mData.data());
if(!batchNormBwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
{
std::cout << "The runtime parameters are not supported by the reference instance, exiting!"
<< std::endl;
return false;
};
auto invoker_ptr_ref = batchNormBwd_ref.MakeInvokerPointer();
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
}
int num_kernel = 0;
bool pass = true;
for(auto& inst_ptr : instance_ptrs)
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(
arrInOutLengths,
arrInOutStrides,
arrInOutStrides,
arrInOutStrides,
arrReduceDims,
arrScaleBiasMeanVarLengths,
arrScaleBiasMeanVarStrides,
arrScaleBiasMeanVarStrides,
arrScaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
dy_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr,
haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr,
epsilon,
PassThroughOp{},
dx_dev.GetDeviceBuffer(),
dscale_dev.GetDeviceBuffer(),
dbias_dev.GetDeviceBuffer());
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
}
else
{
if(time_kernel)
{
std::cout << inst_ptr->GetTypeString()
<< " skipped due to unsupported argument" << std::endl;
}
continue;
};
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
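// With time_kernel == true, Run is expected to time the kernel and return the
// average elapsed time in ms; otherwise it launches once and returns 0.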
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
size_t num_bytes = 0;
// reading of x, dy, scale; writing of dx, dscale, dbias
num_bytes += total_length * (sizeof(XDataType) + sizeof(DyDataType) + sizeof(DxDataType)) +
invariant_length * sizeof(DscaleDbiasDataType) * 2;
// reading of savedMean, savedInvVar
if(haveSavedMeanInvVar)
num_bytes += invariant_length * sizeof(MeanVarDataType) * 2;
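// avg_time is in ms, so bytes / 1e6 / ms equals bytes / 1e9 / s, i.e. GB/s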
float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel)
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< inst_ptr->GetTypeString() << std::endl;
if(avg_time < best_avg_time)
{
best_instance_name = inst_ptr->GetTypeString();
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
using ck::utils::check_err;
bool single_pass = true;
dx_dev.FromDevice(dx.mData.data());
dscale_dev.FromDevice(dscale.mData.data());
dbias_dev.FromDevice(dbias.mData.data());
// clang-format off
single_pass = single_pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:", 5e-4, 5e-4);
single_pass = single_pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 3e-3, 3e-3);
single_pass = single_pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 3e-3, 3e-3);
// clang-format on
pass = pass && single_pass;
};
if(do_dumpout)
{
using ck::host_common::dumpBufferToFile;
// clang-format off
dumpBufferToFile("dump_x.bin", x.mData.data(), x.mDesc.GetElementSize());
dumpBufferToFile("dump_dy.bin", dy.mData.data(), dy.mDesc.GetElementSize());
dumpBufferToFile("dump_dx.bin", dx.mData.data(), dx.mDesc.GetElementSize());
dumpBufferToFile("dump_dx_ref.bin", dx_ref.mData.data(), dx_ref.mDesc.GetElementSize());
dumpBufferToFile("dump_dscale.bin", dscale.mData.data(), dscale.mDesc.GetElementSize());
dumpBufferToFile("dump_dscale_ref.bin", dscale_ref.mData.data(), dscale_ref.mDesc.GetElementSize());
// clang-format on
};
}
if(time_kernel)
{
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return pass;
}
} // namespace profiler
} // namespace ck