Commit 7a4f83e0 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Replace multiD with multiABD

parent 95479e67
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using InDataType = int8_t; using InDataType = int8_t;
using WeiDataType = int8_t; using WeiDataType = int8_t;
...@@ -33,7 +33,7 @@ template <ck::index_t NDimSpatial, ...@@ -33,7 +33,7 @@ template <ck::index_t NDimSpatial,
typename RequantScaleLayout, typename RequantScaleLayout,
typename OutLayout> typename OutLayout>
using DeviceGroupedConvNDFwdInstance = using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using InDataType = int8_t; using InDataType = int8_t;
using WeiDataType = int8_t; using WeiDataType = int8_t;
...@@ -31,7 +31,7 @@ template <ck::index_t NDimSpatial, ...@@ -31,7 +31,7 @@ template <ck::index_t NDimSpatial,
typename BiasLayout, typename BiasLayout,
typename OutLayout> typename OutLayout>
using DeviceGroupedConvNDFwdInstance = using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using InDataType = int8_t; using InDataType = int8_t;
using WeiDataType = int8_t; using WeiDataType = int8_t;
...@@ -31,7 +31,7 @@ template <ck::index_t NDimSpatial, ...@@ -31,7 +31,7 @@ template <ck::index_t NDimSpatial,
typename RequantScaleLayout, typename RequantScaleLayout,
typename OutLayout> typename OutLayout>
using DeviceGroupedConvNDFwdInstance = using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using InDataType = int8_t; using InDataType = int8_t;
using WeiDataType = int8_t; using WeiDataType = int8_t;
...@@ -26,7 +26,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio ...@@ -26,7 +26,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout> template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance = using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -47,7 +47,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio ...@@ -47,7 +47,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
template <typename OutElementOp> template <typename OutElementOp>
using DeviceGroupedConvNDFwdInstance = using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -47,7 +47,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio ...@@ -47,7 +47,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
template <typename OutElementOp> template <typename OutElementOp>
using DeviceGroupedConvNDFwdInstance = using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -44,7 +44,7 @@ template <typename DataType, ...@@ -44,7 +44,7 @@ template <typename DataType,
typename InElementOp, typename InElementOp,
typename WeiElementOp> typename WeiElementOp>
using DeviceGroupedConvNDMultiABFwdInstance = using DeviceGroupedConvNDMultiABFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp"
...@@ -216,18 +216,18 @@ template <index_t NDimSpatial, ...@@ -216,18 +216,18 @@ template <index_t NDimSpatial,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector> index_t CThreadTransferDstScalarPerVector>
struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
: public DeviceGroupedConvFwdMultipleD<NDimSpatial, : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout, ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
ELayout, ELayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType, DsDataType,
EDataType, EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation>
{ {
using DeviceOp = DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK; using DeviceOp = DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK;
......
...@@ -1090,7 +1090,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -1090,7 +1090,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle" str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
...@@ -92,18 +92,18 @@ template <index_t NDimSpatial, ...@@ -92,18 +92,18 @@ template <index_t NDimSpatial,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1> ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
: public DeviceGroupedConvFwdMultipleD<NDimSpatial, : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout, ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
ELayout, ELayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType, DsDataType,
EDataType, EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle; using DeviceOp = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle;
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <memory> #include <memory>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -24,66 +24,66 @@ using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; ...@@ -24,66 +24,66 @@ using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK // grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<>, ck::Tuple<>,
NDHWGK, NDHWGK,
ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>,
ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>,
ck::Tuple<>, ck::Tuple<>,
BF16, BF16,
ScaleAdd, ScaleAdd,
ScaleAdd, ScaleAdd,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances( void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<>, ck::Tuple<>,
NDHWGK, NDHWGK,
ck::Tuple<F16, F16>, ck::Tuple<F16, F16>,
ck::Tuple<F16, F16>, ck::Tuple<F16, F16>,
ck::Tuple<>, ck::Tuple<>,
F16, F16,
ScaleAdd, ScaleAdd,
ScaleAdd, ScaleAdd,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instances( void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<>, ck::Tuple<>,
NDHWGK, NDHWGK,
ck::Tuple<F32, F32>, ck::Tuple<F32, F32>,
ck::Tuple<F32, F32>, ck::Tuple<F32, F32>,
ck::Tuple<>, ck::Tuple<>,
F32, F32,
ScaleAdd, ScaleAdd,
ScaleAdd, ScaleAdd,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_INT8 #ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instances( void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<>, ck::Tuple<>,
NDHWGK, NDHWGK,
ck::Tuple<int8_t, int8_t>, ck::Tuple<int8_t, int8_t>,
ck::Tuple<int8_t, int8_t>, ck::Tuple<int8_t, int8_t>,
ck::Tuple<>, ck::Tuple<>,
int8_t, int8_t,
ScaleAdd, ScaleAdd,
ScaleAdd, ScaleAdd,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
...@@ -96,7 +96,7 @@ template <ck::index_t NumDimSpatial, ...@@ -96,7 +96,7 @@ template <ck::index_t NumDimSpatial,
typename DDataTypes, typename DDataTypes,
typename OutDataType, typename OutDataType,
typename ComputeType> typename ComputeType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -111,19 +111,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -111,19 +111,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ComputeType>> ComputeType>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
DLayouts, WeiLayout,
OutLayout, DLayouts,
InDataType, OutLayout,
WeiDataType, InDataType,
DDataTypes, WeiDataType,
OutDataType, DDataTypes,
ck::tensor_operation::element_wise::ScaleAdd, OutDataType,
ck::tensor_operation::element_wise::ScaleAdd, ck::tensor_operation::element_wise::ScaleAdd,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::ScaleAdd,
ComputeType>; ck::tensor_operation::element_wise::PassThrough,
ComputeType>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <memory> #include <memory>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -24,66 +24,66 @@ using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAd ...@@ -24,66 +24,66 @@ using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAd
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances( void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK, NDHWGK,
BF16, BF16,
BF16, BF16,
ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
ScaleAddScaleAddRelu>>>& instances); ScaleAddScaleAddRelu>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances( void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK, NDHWGK,
F16, F16,
F16, F16,
ck::Tuple<F16, F16>, ck::Tuple<F16, F16>,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
ScaleAddScaleAddRelu>>>& instances); ScaleAddScaleAddRelu>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances( void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK, NDHWGK,
F32, F32,
F32, F32,
ck::Tuple<F32, F32>, ck::Tuple<F32, F32>,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
ScaleAddScaleAddRelu>>>& instances); ScaleAddScaleAddRelu>>>& instances);
#endif #endif
#ifdef CK_ENABLE_INT8 #ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances( void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, NDHWGK>,
NDHWGK, NDHWGK,
int8_t, int8_t,
int8_t, int8_t,
ck::Tuple<F32, F32>, ck::Tuple<F32, F32>,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
ScaleAddScaleAddRelu>>>& instances); ScaleAddScaleAddRelu>>>& instances);
#endif #endif
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
...@@ -96,7 +96,7 @@ template <ck::index_t NumDimSpatial, ...@@ -96,7 +96,7 @@ template <ck::index_t NumDimSpatial,
typename DDataTypes, typename DDataTypes,
typename OutDataType, typename OutDataType,
typename ComputeType> typename ComputeType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -112,19 +112,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -112,19 +112,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ComputeType>> ComputeType>>
{ {
using DeviceOp = using DeviceOp =
DeviceGroupedConvFwdMultipleD<NumDimSpatial, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
DLayouts, DLayouts,
OutLayout, OutLayout,
InDataType, InDataType,
WeiDataType, WeiDataType,
DDataTypes, DDataTypes,
OutDataType, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::ScaleAddScaleAddRelu, ck::tensor_operation::element_wise::ScaleAddScaleAddRelu,
ComputeType>; ComputeType>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -20,96 +20,96 @@ namespace instance { ...@@ -20,96 +20,96 @@ namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul2_Clamp<PassThrough>>>>& Add_Activation_Mul2_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul2_Clamp<Relu>>>>& Add_Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>& Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances); instances);
#endif #endif
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul2_Clamp<PassThrough>>>>& Add_Activation_Mul2_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul2_Clamp<Relu>>>>& Add_Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>& Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances); instances);
// piecewise activation function // piecewise activation function
...@@ -123,7 +123,7 @@ template <ck::index_t NumDimSpatial, ...@@ -123,7 +123,7 @@ template <ck::index_t NumDimSpatial,
typename DsDataType, typename DsDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -137,18 +137,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -137,18 +137,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul2_Clamp<Activation>>> Add_Activation_Mul2_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
DsLayout, WeiLayout,
OutLayout, DsLayout,
InDataType, OutLayout,
WeiDataType, InDataType,
DsDataType, WeiDataType,
OutDataType, DsDataType,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul2_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul2_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -193,7 +194,7 @@ template <ck::index_t NumDimSpatial, ...@@ -193,7 +194,7 @@ template <ck::index_t NumDimSpatial,
typename DsDataType, typename DsDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -207,18 +208,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -207,18 +208,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>> Add_Mul2_Activation_Mul_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
DsLayout, WeiLayout,
OutLayout, DsLayout,
InDataType, OutLayout,
WeiDataType, InDataType,
DsDataType, WeiDataType,
OutDataType, DsDataType,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -20,94 +20,96 @@ namespace instance { ...@@ -20,94 +20,96 @@ namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul_Clamp<PassThrough>>>>& Add_Activation_Mul_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul_Clamp<Relu>>>>& Add_Activation_Mul_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
GK_Tuple, GKYXC,
NHWGK, GK_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
I32_Tuple, int8_t,
int8_t, I32_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>& PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances); instances);
#endif #endif
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul_Clamp<PassThrough>>>>& Add_Activation_Mul_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul_Clamp<Relu>>>>& Add_Activation_Mul_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
GK_Tuple, GKYXC,
NHWGK, GK_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
I32_Tuple, int8_t,
int8_t, I32_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>& PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances); instances);
// piecewise activation function // piecewise activation function
...@@ -121,7 +123,7 @@ template <ck::index_t NumDimSpatial, ...@@ -121,7 +123,7 @@ template <ck::index_t NumDimSpatial,
typename DsDataType, typename DsDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -135,18 +137,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -135,18 +137,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul_Clamp<Activation>>> Add_Activation_Mul_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
DsLayout, WeiLayout,
OutLayout, DsLayout,
InDataType, OutLayout,
WeiDataType, InDataType,
DsDataType, WeiDataType,
OutDataType, DsDataType,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -191,7 +194,7 @@ template <ck::index_t NumDimSpatial, ...@@ -191,7 +194,7 @@ template <ck::index_t NumDimSpatial,
typename DsDataType, typename DsDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -205,18 +208,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -205,18 +208,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>> Add_Mul_Activation_Mul_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
DsLayout, WeiLayout,
OutLayout, DsLayout,
InDataType, OutLayout,
WeiDataType, InDataType,
DsDataType, WeiDataType,
OutDataType, DsDataType,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -19,63 +19,65 @@ namespace instance { ...@@ -19,63 +19,65 @@ namespace instance {
#ifdef DL_KERNELS #ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perchannel_quantization_int8_instances( void add_device_conv2d_dl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
GK_Tuple, GKYXC,
NHWGK, GK_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
F32_Tuple, int8_t,
int8_t, F32_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>& PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances( void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
F32_Tuple, F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Activation_Mul2_Clamp<Relu>>>>& Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
#endif #endif
void add_device_conv2d_xdl_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
GK_Tuple, GKYXC,
NHWGK, GK_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
F32_Tuple, int8_t,
int8_t, F32_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>& PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
F32_Tuple, F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Activation_Mul2_Clamp<Relu>>>>& Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
...@@ -88,7 +90,7 @@ template <ck::index_t NumDimSpatial, ...@@ -88,7 +90,7 @@ template <ck::index_t NumDimSpatial,
typename DsDataType, typename DsDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -102,18 +104,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -102,18 +104,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Activation_Mul2_Clamp<Activation>>> Activation_Mul2_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
GK_Tuple, WeiLayout,
OutLayout, GK_Tuple,
InDataType, OutLayout,
WeiDataType, InDataType,
F32_Tuple, WeiDataType,
OutDataType, F32_Tuple,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Activation_Mul2_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Activation_Mul2_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -19,63 +19,65 @@ namespace instance { ...@@ -19,63 +19,65 @@ namespace instance {
#ifdef DL_KERNELS #ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perlayer_quantization_int8_instances( void add_device_conv2d_dl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
Empty_Tuple, GKYXC,
NHWGK, Empty_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
Empty_Tuple, int8_t,
int8_t, Empty_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Activation_Mul_Clamp<PassThrough>>>>& PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances( void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Activation_Mul_Clamp<Relu>>>>& Activation_Mul_Clamp<Relu>>>>&
instances); instances);
#endif #endif
void add_device_conv2d_xdl_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
Empty_Tuple, GKYXC,
NHWGK, Empty_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
Empty_Tuple, int8_t,
int8_t, Empty_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Activation_Mul_Clamp<PassThrough>>>>& PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Activation_Mul_Clamp<Relu>>>>& Activation_Mul_Clamp<Relu>>>>&
instances); instances);
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
...@@ -86,7 +88,7 @@ template <ck::index_t NumDimSpatial, ...@@ -86,7 +88,7 @@ template <ck::index_t NumDimSpatial,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -100,18 +102,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -100,18 +102,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>> Activation_Mul_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
Empty_Tuple, WeiLayout,
OutLayout, Empty_Tuple,
InDataType, OutLayout,
WeiDataType, InDataType,
Empty_Tuple, WeiDataType,
OutDataType, Empty_Tuple,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment