Commit 494a8fa4 authored by rocking's avatar rocking
Browse files

Add the full example of avgpool bwd

parent ffd7913b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp"
using DOutDataType = float;
using DInDataType = float;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
......@@ -38,6 +33,17 @@ bool pool3d_bwd_test(bool do_verification,
std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> dinput_right_pads)
{
using DevicePoolBwdInstance =
ck::tensor_operation::device::DeviceAvgPool3dBwdImpl<DOutDataType,
DInDataType,
ComputeDataType, // ComputeDataType
64, // BlockSize
64, // ReduceMThreadClusterSize
1, // ReduceKThreadClusterSize
1, // ReduceMThreadSliceSize
1, // ReduceKThreadSliceSize
1>; // InSrcOutDstVectorSize
auto OutSpatialLength = [&](auto InSpatialLength, int index) {
ck::index_t left_pad = dinput_left_pads[index];
ck::index_t right_pad = dinput_right_pads[index];
......@@ -50,15 +56,53 @@ bool pool3d_bwd_test(bool do_verification,
ck::index_t Ho = OutSpatialLength(Hi, 1);
ck::index_t Wo = OutSpatialLength(Wi, 1);
Tensor<DOutDataType> dout(HostTensorDescriptor({N, C, Do, Ho, Wo}));
Tensor<DInDataType> din_dev(HostTensorDescriptor({N, C, Di, Hi, Wi}));
Tensor<DInDataType> din_host(HostTensorDescriptor({N, C, Di, Hi, Wi}));
auto f_host_tensor_descriptor =
[](std::size_t N_, std::size_t C_, std::size_t D, std::size_t H, std::size_t W) {
using namespace ck::literals;
return HostTensorDescriptor({N_, C_, D, H, W},
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
};
Tensor<DOutDataType> dout(f_host_tensor_descriptor(N, C, Do, Ho, Wo));
Tensor<DInDataType> din_dev(f_host_tensor_descriptor(N, C, Di, Hi, Wi));
Tensor<DInDataType> din_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi));
std::cout << "dout: " << dout.mDesc << std::endl;
std::cout << "din_host: " << din_host.mDesc << std::endl;
dout.GenerateTensorValue(GeneratorTensor_3<DOutDataType>{0.0, 1.0});
DeviceMem dout_device_buf(sizeof(DOutDataType) * dout.mDesc.GetElementSpaceSize());
DeviceMem din_device_buf(sizeof(DInDataType) * din_dev.mDesc.GetElementSpaceSize());
dout_device_buf.ToDevice(dout.mData.data());
auto pool = DevicePoolBwdInstance{};
auto invoker_ptr = pool.MakeInvokerPointer();
auto argument_ptr =
pool.MakeArgumentPointer(static_cast<DOutDataType*>(dout_device_buf.GetDeviceBuffer()),
static_cast<DInDataType*>(din_device_buf.GetDeviceBuffer()),
{N, C, Do, Ho, Wo},
{N, C, Di, Hi, Wi},
window_lengths,
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C},
window_strides,
window_dilations,
dinput_left_pads,
dinput_right_pads);
if(!pool.IsSupportedArgument(argument_ptr.get()))
{
throw std::runtime_error("wrong! device_op with the specified compilation parameters does "
"not support this problem");
}
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::cout << "Perf: " << ave_time << std::endl;
bool pass = true;
if(do_verification)
{
auto ref_pool =
......@@ -75,11 +119,12 @@ bool pool3d_bwd_test(bool do_verification,
dinput_right_pads);
ref_invoker.Run(ref_argument);
din_device_buf.FromDevice(din_dev.mData.data());
pass = ck::utils::check_err(din_dev, din_host);
}
// TODO - full example
ck::ignore = time_kernel;
return 0;
return pass;
}
int main()
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/utility/reduction_enums.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DOutDataType, typename DInDataType>
struct DeviceAvgPoolBwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_dout,
void* p_din,
std::vector<ck::index_t> dout_n_k_wos_lengths,
std::vector<ck::index_t> dout_n_k_wos_strides,
std::vector<ck::index_t> din_n_k_wos_length,
std::vector<ck::index_t> din_n_k_wos_strides,
std::vector<ck::index_t> window_k_c_xs_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DOutDataType,
typename DInDataType,
typename ComputeDataType,
ck::index_t BlockSize,
ck::index_t MThreadClusterSize,
ck::index_t KThreadClusterSize,
ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataType>
{
struct Argument : public BaseArgument
{
Argument() {}
};
struct Invoker : public BaseInvoker
{
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
ignore = p_arg;
ignore = stream_config;
return 0;
}
};
static bool IsSupportedArgument(const Argument& arg)
{
ignore = arg;
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_dout,
void* p_din,
std::vector<ck::index_t> dout_n_k_wos_lengths,
std::vector<ck::index_t> din_n_k_wos_length,
std::vector<ck::index_t> window_k_c_xs_lengths,
std::vector<ck::index_t> dout_n_k_wos_strides,
std::vector<ck::index_t> din_n_k_wos_strides,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) override
{
ignore = p_dout;
ignore = p_din;
ignore = dout_n_k_wos_lengths;
ignore = dout_n_k_wos_strides;
ignore = din_n_k_wos_length;
ignore = din_n_k_wos_strides;
ignore = window_k_c_xs_lengths;
ignore = window_strides;
ignore = window_dilations;
ignore = input_left_pads;
ignore = input_right_pads;
return std::make_unique<Argument>();
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceAvgPool3dBwd<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment