Unverified Commit 38470e04 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Add client example of grouped conv2d backward weight (data type: fp16) (#498)

* Remove redundant CMake setting

* Extract common code from files

* Rename folder 'convnd' to 'conv'

* Use std::array<> to accept compile-time kwnown # of arguments

* Fix compilation error of tuning parameter

* In example, use same setting as unit-test

* Remove no-longer used include directive

* Add interface for grouped conv bwd weight

* Add group support for conv bwd weight

* Add grouped conv bwd weight example

* Use group parameter in example

* Rename example folder

* Remove non-grouped version example source files

* Rename device op template

* Add group support to convolution backward weight

* Remove debug messages

* Use smaller group size in example

* Use named variable as loop terminate condition

* Prettify example output message

* Enlarge used grid size

* Allow real grid size exceeds expected grid size

* Rename interface file

* Add client example for grouped conv2d bwd weight

* Fix wrong include directive

* Rename client example folder
parent 67423a22
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f32_default_instances =
std::tuple<
// clang-format off
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
// clang-format on
>;
using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f32_instances =
std::tuple<
// clang-format off
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
// clang-format on
>;
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f32_default_instances{});
add_device_operation_instances(
instances,
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f32_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -20,8 +20,8 @@ set(PROFILER_SOURCE ...@@ -20,8 +20,8 @@ set(PROFILER_SOURCE
src/profile_conv_fwd_bias_relu.cpp src/profile_conv_fwd_bias_relu.cpp
src/profile_conv_fwd_bias_relu_add.cpp src/profile_conv_fwd_bias_relu_add.cpp
src/profile_conv_bwd_data.cpp src/profile_conv_bwd_data.cpp
src/profile_conv_bwd_weight.cpp
src/profile_grouped_conv_fwd.cpp src/profile_grouped_conv_fwd.cpp
src/profile_grouped_conv_bwd_weight.cpp
src/profile_reduce.cpp src/profile_reduce.cpp
src/profile_groupnorm.cpp src/profile_groupnorm.cpp
src/profile_layernorm.cpp src/profile_layernorm.cpp
...@@ -49,9 +49,9 @@ target_link_libraries(ckProfiler PRIVATE device_grouped_conv3d_fwd_instance) ...@@ -49,9 +49,9 @@ target_link_libraries(ckProfiler PRIVATE device_grouped_conv3d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv1d_bwd_data_instance) target_link_libraries(ckProfiler PRIVATE device_conv1d_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_conv3d_bwd_data_instance) target_link_libraries(ckProfiler PRIVATE device_conv3d_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_conv1d_bwd_weight_instance) target_link_libraries(ckProfiler PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance) target_link_libraries(ckProfiler PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_conv3d_bwd_weight_instance) target_link_libraries(ckProfiler PRIVATE device_grouped_conv3d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_normalization_instance) target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
......
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
#pragma once #pragma once
#include "ck/ck.hpp" #include <algorithm>
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
#include <iterator>
#include <typeinfo> #include <typeinfo>
#include "ck/ck.hpp" #include "ck/ck.hpp"
...@@ -13,7 +14,7 @@ ...@@ -13,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -26,32 +27,6 @@ ...@@ -26,32 +27,6 @@
namespace ck { namespace ck {
namespace profiler { namespace profiler {
template <typename DataType>
void show_data_nhwc_layout(Tensor<DataType>& nhwc)
{
std::cout << "[";
for(int n = 0; n < ck::type_convert<int>(nhwc.mDesc.GetLengths()[0]); n++)
{
std::cout << "[";
for(int hi = 0; hi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[2]); hi++)
{
std::cout << "[";
for(int wi = 0; wi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[3]); wi++)
{
std::cout << "[";
for(int c = 0; c < ck::type_convert<int>(nhwc.mDesc.GetLengths()[1]); c++)
{
std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " ";
}
std::cout << "]";
}
std::cout << "]";
}
std::cout << "]";
}
std::cout << "]";
}
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
...@@ -59,12 +34,12 @@ template <ck::index_t NDimSpatial, ...@@ -59,12 +34,12 @@ template <ck::index_t NDimSpatial,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType> typename OutDataType>
bool profile_conv_bwd_weight_impl(int do_verification, bool profile_grouped_conv_bwd_weight_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
bool time_kernel, bool time_kernel,
const ck::utils::conv::ConvParam& conv_param, const ck::utils::conv::ConvParam& conv_param,
ck::index_t split_k) ck::index_t split_k)
{ {
using InElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
...@@ -114,16 +89,14 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -114,16 +89,14 @@ bool profile_conv_bwd_weight_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial, auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp>{}; OutElementOp>{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(input, auto ref_argument = ref_conv.MakeArgument(input,
weight_host_result, weight_host_result,
output, output,
...@@ -138,16 +111,16 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -138,16 +111,16 @@ bool profile_conv_bwd_weight_impl(int do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
using DeviceOp = ck::tensor_operation::device::DeviceConvBwdWeight<NDimSpatial, using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout, OutLayout,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp>; OutElementOp>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
...@@ -163,22 +136,41 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -163,22 +136,41 @@ bool profile_conv_bwd_weight_impl(int do_verification,
// profile device Conv instances // profile device Conv instances
bool all_pass = true; bool all_pass = true;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.G_,
conv_param.N_, conv_param.N_,
conv_param.K_, conv_param.K_,
conv_param.C_, conv_param.C_,
conv_param.input_spatial_lengths_, input_spatial_lengths,
conv_param.filter_spatial_lengths_, filter_spatial_lengths,
conv_param.output_spatial_lengths_, output_spatial_lengths,
conv_param.conv_filter_strides_, conv_filter_strides,
conv_param.conv_filter_dilations_, conv_filter_dilations,
conv_param.input_left_pads_, input_left_pads,
conv_param.input_right_pads_, input_right_pads,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
...@@ -218,32 +210,29 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -218,32 +210,29 @@ bool profile_conv_bwd_weight_impl(int do_verification,
wei_device_buf.FromDevice(weight_device_result.mData.data()); wei_device_buf.FromDevice(weight_device_result.mData.data());
bool pass = bool pass =
ck::utils::check_err(weight_host_result.mData, weight_device_result.mData); ck::utils::check_err(weight_device_result.mData, weight_host_result.mData);
if(!pass) if(!pass)
{ {
std::cout << "Fail info:" << op_ptr->GetTypeString() << std::endl; std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl;
} }
all_pass &= pass; all_pass &= pass;
if(do_log) if(do_log)
{ {
std::cout << "in : "; LogRangeAsType<float>(std::cout << "output : ", output.mData, ",") << std::endl;
show_data_nhwc_layout(output); ;
std::cout << std::endl; LogRangeAsType<float>(
std::cout << "weight (device): ", weight_device_result.mData, ",")
std::cout << "wei: "; << std::endl;
show_data_nhwc_layout(weight_host_result); ;
std::cout << std::endl; LogRangeAsType<float>(
std::cout << "weight (host): ", weight_host_result.mData, ",")
std::cout << "out : "; << std::endl;
show_data_nhwc_layout(input); ;
std::cout << std::endl; LogRangeAsType<float>(std::cout << "input: ", input.mData, ",") << std::endl;
;
std::cout << "wei_device: ";
show_data_nhwc_layout(weight_device_result);
std::cout << std::endl;
} }
} }
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/include/profile_conv_bwd_weight_impl.hpp" #include "profiler/include/profile_grouped_conv_bwd_weight_impl.hpp"
namespace { namespace {
enum struct ConvLayout enum struct ConvLayout
{ {
NCHW_KCYX_NKHW, // 0 GNCHW_GKCYX_GNKHW, // 0
NHWC_KYXC_NHWK, // 1 GNHWC_GKYXC_GNHWK, // 1
}; };
enum struct ConvDataType enum struct ConvDataType
...@@ -25,24 +25,25 @@ enum struct ConvDataType ...@@ -25,24 +25,25 @@ enum struct ConvDataType
static void print_helper_msg() static void print_helper_msg()
{ {
std::cout std::cout << "arg1: tensor operation (conv_bwd_weight: Convolution Backward Weight\n"
<< "arg1: tensor operation (conv_bwd_weight: Convolution Backward Weight\n" << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" << " 1: Input fp16, Weight fp16, Output fp16\n"
<< " 1: Input fp16, Weight fp16, Output fp16\n" << " 2: Input bf16, Weight fp32, Output bf16)\n"
<< " 2: Input bf16, Weight fp32, Output bf16)\n" << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
<< "arg3: tensor layout (0: Input[N, C, Hi, Wi], Weight[K, C, Y, X], Output[N, K, Ho, Wo]\n" "N, K, Ho, Wo]\n"
<< " 1: Input[N, Hi, Wi, C], Weight[K, Y, X, C], Output[N, Ho, Wo, K]\n" << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
<< "arg4: verification (0: no, 1: yes)\n" "N, Ho, Wo, K]\n"
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" << "arg4: verification (0: no, 1: yes)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n" << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: time kernel (0: no, 1: yes)\n" << "arg6: print tensor value (0: no; 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << " SplitK\n" << "arg7: time kernel (0: no, 1: yes)\n"
<< std::endl; << ck::utils::conv::get_conv_param_parser_helper_msg() << " SplitK\n"
<< std::endl;
} }
} // namespace } // namespace
int profile_conv_bwd_weight(int argc, char* argv[]) int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{ {
// 8 for control, 1 for num_dim_spatial // 8 for control, 1 for num_dim_spatial
if(argc < 9) if(argc < 9)
...@@ -75,17 +76,17 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -75,17 +76,17 @@ int profile_conv_bwd_weight(int argc, char* argv[])
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using NWC = ck::tensor_layout::convolution::NWC; using GNWC = ck::tensor_layout::convolution::GNWC;
using NHWC = ck::tensor_layout::convolution::NHWC; using GNHWC = ck::tensor_layout::convolution::GNHWC;
using NDHWC = ck::tensor_layout::convolution::NDHWC; using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
using KXC = ck::tensor_layout::convolution::KXC; using GKXC = ck::tensor_layout::convolution::GKXC;
using KYXC = ck::tensor_layout::convolution::KYXC; using GKYXC = ck::tensor_layout::convolution::GKYXC;
using KZYXC = ck::tensor_layout::convolution::KZYXC; using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using NWK = ck::tensor_layout::convolution::NWK; using GNWK = ck::tensor_layout::convolution::GNWK;
using NHWK = ck::tensor_layout::convolution::NHWK; using GNHWK = ck::tensor_layout::convolution::GNHWK;
using NDHWK = ck::tensor_layout::convolution::NDHWK; using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
constexpr auto I1 = ck::Number<1>{}; constexpr auto I1 = ck::Number<1>{};
constexpr auto I2 = ck::Number<2>{}; constexpr auto I2 = ck::Number<2>{};
...@@ -108,64 +109,64 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -108,64 +109,64 @@ int profile_conv_bwd_weight(int argc, char* argv[])
using WeiDataType = decltype(wei_type); using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type); using OutDataType = decltype(out_type);
bool pass = ck::profiler::profile_conv_bwd_weight_impl<NDimSpatial, bool pass = ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout, OutLayout,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType>( OutDataType>(
do_verification, init_method, do_log, time_kernel, params, split_k); do_verification, init_method, do_log, time_kernel, params, split_k);
return pass ? 0 : 1; return pass ? 0 : 1;
}; };
if(num_dim_spatial == 1 && layout == ConvLayout::NHWC_KYXC_NHWK) if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I1, NWC{}, KXC{}, NWK{}, F32{}, F32{}, F32{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I1, NWC{}, KXC{}, NWK{}, F16{}, F16{}, F16{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) else if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, F32{}, BF16{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{});
} }
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK) else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I2, NHWC{}, KYXC{}, NHWK{}, F32{}, F32{}, F32{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, NHWC{}, KYXC{}, NHWK{}, F16{}, F16{}, F16{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) else if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, F32{}, BF16{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{});
} }
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK) else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F32{}, F32{}, F32{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F16{}, F16{}, F16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) else if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, F32{}, BF16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{});
} }
} }
......
This diff is collapsed.
...@@ -45,9 +45,9 @@ add_subdirectory(batched_gemm_softmax_gemm_permute) ...@@ -45,9 +45,9 @@ add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm)
add_subdirectory(reduce) add_subdirectory(reduce)
add_subdirectory(convnd_fwd) add_subdirectory(convnd_fwd)
add_subdirectory(convnd_bwd_weight)
add_subdirectory(convnd_bwd_data) add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd) add_subdirectory(grouped_convnd_fwd)
add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map) add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax) add_subdirectory(softmax)
add_subdirectory(normalization) add_subdirectory(normalization)
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment