"docs/vscode:/vscode.git/clone" did not exist on "e04e401613bf13a235453d1f9ab06ef8bf4cf7fe"
Unverified Commit 38470e04 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Add client example of grouped conv2d backward weight (data type: fp16) (#498)

* Remove redundant CMake setting

* Extract common code from files

* Rename folder 'convnd' to 'conv'

* Use std::array<> to accept compile-time kwnown # of arguments

* Fix compilation error of tuning parameter

* In example, use same setting as unit-test

* Remove no-longer used include directive

* Add interface for grouped conv bwd weight

* Add group support for conv bwd weight

* Add grouped conv bwd weight example

* Use group parameter in example

* Rename example folder

* Remove non-grouped version example source files

* Rename device op template

* Add group support to convolution backward weight

* Remove debug messages

* Use smaller group size in example

* Use named variable as loop terminate condition

* Prettify example output message

* Enlarge used grid size

* Allow real grid size exceeds expected grid size

* Rename interface file

* Add client example for grouped conv2d bwd weight

* Fix wrong include directive

* Rename client example folder
parent 67423a22
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f32_default_instances =
std::tuple<
// clang-format off
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
// clang-format on
>;
using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f32_instances =
std::tuple<
// clang-format off
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
// clang-format on
>;
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f32_default_instances{});
add_device_operation_instances(
instances,
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f32_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -20,8 +20,8 @@ set(PROFILER_SOURCE ...@@ -20,8 +20,8 @@ set(PROFILER_SOURCE
src/profile_conv_fwd_bias_relu.cpp src/profile_conv_fwd_bias_relu.cpp
src/profile_conv_fwd_bias_relu_add.cpp src/profile_conv_fwd_bias_relu_add.cpp
src/profile_conv_bwd_data.cpp src/profile_conv_bwd_data.cpp
src/profile_conv_bwd_weight.cpp
src/profile_grouped_conv_fwd.cpp src/profile_grouped_conv_fwd.cpp
src/profile_grouped_conv_bwd_weight.cpp
src/profile_reduce.cpp src/profile_reduce.cpp
src/profile_groupnorm.cpp src/profile_groupnorm.cpp
src/profile_layernorm.cpp src/profile_layernorm.cpp
...@@ -49,9 +49,9 @@ target_link_libraries(ckProfiler PRIVATE device_grouped_conv3d_fwd_instance) ...@@ -49,9 +49,9 @@ target_link_libraries(ckProfiler PRIVATE device_grouped_conv3d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv1d_bwd_data_instance) target_link_libraries(ckProfiler PRIVATE device_conv1d_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_conv3d_bwd_data_instance) target_link_libraries(ckProfiler PRIVATE device_conv3d_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_conv1d_bwd_weight_instance) target_link_libraries(ckProfiler PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance) target_link_libraries(ckProfiler PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_conv3d_bwd_weight_instance) target_link_libraries(ckProfiler PRIVATE device_grouped_conv3d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_normalization_instance) target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
......
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
#pragma once #pragma once
#include "ck/ck.hpp" #include <algorithm>
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
#include <iterator>
#include <typeinfo> #include <typeinfo>
#include "ck/ck.hpp" #include "ck/ck.hpp"
...@@ -13,7 +14,7 @@ ...@@ -13,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -26,32 +27,6 @@ ...@@ -26,32 +27,6 @@
namespace ck { namespace ck {
namespace profiler { namespace profiler {
template <typename DataType>
void show_data_nhwc_layout(Tensor<DataType>& nhwc)
{
std::cout << "[";
for(int n = 0; n < ck::type_convert<int>(nhwc.mDesc.GetLengths()[0]); n++)
{
std::cout << "[";
for(int hi = 0; hi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[2]); hi++)
{
std::cout << "[";
for(int wi = 0; wi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[3]); wi++)
{
std::cout << "[";
for(int c = 0; c < ck::type_convert<int>(nhwc.mDesc.GetLengths()[1]); c++)
{
std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " ";
}
std::cout << "]";
}
std::cout << "]";
}
std::cout << "]";
}
std::cout << "]";
}
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
...@@ -59,12 +34,12 @@ template <ck::index_t NDimSpatial, ...@@ -59,12 +34,12 @@ template <ck::index_t NDimSpatial,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType> typename OutDataType>
bool profile_conv_bwd_weight_impl(int do_verification, bool profile_grouped_conv_bwd_weight_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
bool time_kernel, bool time_kernel,
const ck::utils::conv::ConvParam& conv_param, const ck::utils::conv::ConvParam& conv_param,
ck::index_t split_k) ck::index_t split_k)
{ {
using InElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
...@@ -114,16 +89,14 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -114,16 +89,14 @@ bool profile_conv_bwd_weight_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial, auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp>{}; OutElementOp>{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(input, auto ref_argument = ref_conv.MakeArgument(input,
weight_host_result, weight_host_result,
output, output,
...@@ -138,16 +111,16 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -138,16 +111,16 @@ bool profile_conv_bwd_weight_impl(int do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
using DeviceOp = ck::tensor_operation::device::DeviceConvBwdWeight<NDimSpatial, using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout, OutLayout,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp>; OutElementOp>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
...@@ -163,22 +136,41 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -163,22 +136,41 @@ bool profile_conv_bwd_weight_impl(int do_verification,
// profile device Conv instances // profile device Conv instances
bool all_pass = true; bool all_pass = true;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.G_,
conv_param.N_, conv_param.N_,
conv_param.K_, conv_param.K_,
conv_param.C_, conv_param.C_,
conv_param.input_spatial_lengths_, input_spatial_lengths,
conv_param.filter_spatial_lengths_, filter_spatial_lengths,
conv_param.output_spatial_lengths_, output_spatial_lengths,
conv_param.conv_filter_strides_, conv_filter_strides,
conv_param.conv_filter_dilations_, conv_filter_dilations,
conv_param.input_left_pads_, input_left_pads,
conv_param.input_right_pads_, input_right_pads,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
...@@ -218,32 +210,29 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -218,32 +210,29 @@ bool profile_conv_bwd_weight_impl(int do_verification,
wei_device_buf.FromDevice(weight_device_result.mData.data()); wei_device_buf.FromDevice(weight_device_result.mData.data());
bool pass = bool pass =
ck::utils::check_err(weight_host_result.mData, weight_device_result.mData); ck::utils::check_err(weight_device_result.mData, weight_host_result.mData);
if(!pass) if(!pass)
{ {
std::cout << "Fail info:" << op_ptr->GetTypeString() << std::endl; std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl;
} }
all_pass &= pass; all_pass &= pass;
if(do_log) if(do_log)
{ {
std::cout << "in : "; LogRangeAsType<float>(std::cout << "output : ", output.mData, ",") << std::endl;
show_data_nhwc_layout(output); ;
std::cout << std::endl; LogRangeAsType<float>(
std::cout << "weight (device): ", weight_device_result.mData, ",")
std::cout << "wei: "; << std::endl;
show_data_nhwc_layout(weight_host_result); ;
std::cout << std::endl; LogRangeAsType<float>(
std::cout << "weight (host): ", weight_host_result.mData, ",")
std::cout << "out : "; << std::endl;
show_data_nhwc_layout(input); ;
std::cout << std::endl; LogRangeAsType<float>(std::cout << "input: ", input.mData, ",") << std::endl;
;
std::cout << "wei_device: ";
show_data_nhwc_layout(weight_device_result);
std::cout << std::endl;
} }
} }
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/include/profile_conv_bwd_weight_impl.hpp" #include "profiler/include/profile_grouped_conv_bwd_weight_impl.hpp"
namespace { namespace {
enum struct ConvLayout enum struct ConvLayout
{ {
NCHW_KCYX_NKHW, // 0 GNCHW_GKCYX_GNKHW, // 0
NHWC_KYXC_NHWK, // 1 GNHWC_GKYXC_GNHWK, // 1
}; };
enum struct ConvDataType enum struct ConvDataType
...@@ -25,24 +25,25 @@ enum struct ConvDataType ...@@ -25,24 +25,25 @@ enum struct ConvDataType
static void print_helper_msg() static void print_helper_msg()
{ {
std::cout std::cout << "arg1: tensor operation (conv_bwd_weight: Convolution Backward Weight\n"
<< "arg1: tensor operation (conv_bwd_weight: Convolution Backward Weight\n" << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" << " 1: Input fp16, Weight fp16, Output fp16\n"
<< " 1: Input fp16, Weight fp16, Output fp16\n" << " 2: Input bf16, Weight fp32, Output bf16)\n"
<< " 2: Input bf16, Weight fp32, Output bf16)\n" << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
<< "arg3: tensor layout (0: Input[N, C, Hi, Wi], Weight[K, C, Y, X], Output[N, K, Ho, Wo]\n" "N, K, Ho, Wo]\n"
<< " 1: Input[N, Hi, Wi, C], Weight[K, Y, X, C], Output[N, Ho, Wo, K]\n" << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
<< "arg4: verification (0: no, 1: yes)\n" "N, Ho, Wo, K]\n"
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" << "arg4: verification (0: no, 1: yes)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n" << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: time kernel (0: no, 1: yes)\n" << "arg6: print tensor value (0: no; 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << " SplitK\n" << "arg7: time kernel (0: no, 1: yes)\n"
<< std::endl; << ck::utils::conv::get_conv_param_parser_helper_msg() << " SplitK\n"
<< std::endl;
} }
} // namespace } // namespace
int profile_conv_bwd_weight(int argc, char* argv[]) int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{ {
// 8 for control, 1 for num_dim_spatial // 8 for control, 1 for num_dim_spatial
if(argc < 9) if(argc < 9)
...@@ -75,17 +76,17 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -75,17 +76,17 @@ int profile_conv_bwd_weight(int argc, char* argv[])
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using NWC = ck::tensor_layout::convolution::NWC; using GNWC = ck::tensor_layout::convolution::GNWC;
using NHWC = ck::tensor_layout::convolution::NHWC; using GNHWC = ck::tensor_layout::convolution::GNHWC;
using NDHWC = ck::tensor_layout::convolution::NDHWC; using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
using KXC = ck::tensor_layout::convolution::KXC; using GKXC = ck::tensor_layout::convolution::GKXC;
using KYXC = ck::tensor_layout::convolution::KYXC; using GKYXC = ck::tensor_layout::convolution::GKYXC;
using KZYXC = ck::tensor_layout::convolution::KZYXC; using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using NWK = ck::tensor_layout::convolution::NWK; using GNWK = ck::tensor_layout::convolution::GNWK;
using NHWK = ck::tensor_layout::convolution::NHWK; using GNHWK = ck::tensor_layout::convolution::GNHWK;
using NDHWK = ck::tensor_layout::convolution::NDHWK; using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
constexpr auto I1 = ck::Number<1>{}; constexpr auto I1 = ck::Number<1>{};
constexpr auto I2 = ck::Number<2>{}; constexpr auto I2 = ck::Number<2>{};
...@@ -108,64 +109,64 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -108,64 +109,64 @@ int profile_conv_bwd_weight(int argc, char* argv[])
using WeiDataType = decltype(wei_type); using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type); using OutDataType = decltype(out_type);
bool pass = ck::profiler::profile_conv_bwd_weight_impl<NDimSpatial, bool pass = ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout, OutLayout,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType>( OutDataType>(
do_verification, init_method, do_log, time_kernel, params, split_k); do_verification, init_method, do_log, time_kernel, params, split_k);
return pass ? 0 : 1; return pass ? 0 : 1;
}; };
if(num_dim_spatial == 1 && layout == ConvLayout::NHWC_KYXC_NHWK) if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I1, NWC{}, KXC{}, NWK{}, F32{}, F32{}, F32{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I1, NWC{}, KXC{}, NWK{}, F16{}, F16{}, F16{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) else if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, F32{}, BF16{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{});
} }
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK) else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I2, NHWC{}, KYXC{}, NHWK{}, F32{}, F32{}, F32{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, NHWC{}, KYXC{}, NHWK{}, F16{}, F16{}, F16{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) else if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, F32{}, BF16{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{});
} }
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK) else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F32{}, F32{}, F32{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F16{}, F16{}, F16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) else if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, F32{}, BF16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{});
} }
} }
......
This diff is collapsed.
...@@ -45,9 +45,9 @@ add_subdirectory(batched_gemm_softmax_gemm_permute) ...@@ -45,9 +45,9 @@ add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm)
add_subdirectory(reduce) add_subdirectory(reduce)
add_subdirectory(convnd_fwd) add_subdirectory(convnd_fwd)
add_subdirectory(convnd_bwd_weight)
add_subdirectory(convnd_bwd_data) add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd) add_subdirectory(grouped_convnd_fwd)
add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map) add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax) add_subdirectory(softmax)
add_subdirectory(normalization) add_subdirectory(normalization)
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment