Commit acd980fc authored by rocking

Add device pool bwd device op

parent afe8dae6
@@ -10,7 +10,7 @@
 #include "ck/utility/reduction_functions_accumulate.hpp"
 #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -29,9 +29,7 @@ template <typename InDataType,
           typename DOutDataType,
           typename InLayout,
           typename OutLayout,
-          ck::ReduceTensorOp ReduceOpId,
-          bool PropagateNan,
-          ck::InMemoryDataOperationEnum Memop>
+          bool PropagateNan>
 bool maxpool_bwd_test(bool do_verification,
                       bool time_kernel,
                       ck::index_t N,
@@ -55,7 +53,7 @@ bool maxpool_bwd_test(bool do_verification,
         OutDataType,     // OutDataType
         IndexDataType,   // IndexDataType
         ComputeDataType, // ComputeDataType
-        ReduceOpId,
+        ck::ReduceTensorOp::MAX,
         true, // OutputIndex
         64,   // BlockSize
         64,   // ReduceMThreadClusterSize
@@ -65,7 +63,7 @@ bool maxpool_bwd_test(bool do_verification,
         1>; // InSrcOutDstVectorSize

     using DeviceMaxPoolBwdInstance = ck::tensor_operation::device::
-        DevicePutElementImpl<DOutDataType, IndexDataType, DInDataType, PassThrough, Memop, 4>;
+        DeviceIndexPoolBwdImpl<DOutDataType, IndexDataType, DInDataType, 4>;

     const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
     const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
@@ -140,7 +138,7 @@ bool maxpool_bwd_test(bool do_verification,
         static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
         static_cast<IndexDataType*>(indices_device_buf.GetDeviceBuffer()),
         {N, C, Hi, Wi},
-        {Y, X},
+        window_spatial_lengths,
         {N, C, Ho, Wo},
         {C * Hi * Wi, 1, Wi * C, C},
         {C * Ho * Wo, 1, Wo * C, C},
@@ -167,7 +165,8 @@ bool maxpool_bwd_test(bool do_verification,
         static_cast<DInDataType*>(din_device_buf.GetDeviceBuffer()),
         dout_n_c_ho_wo.mDesc.GetElementSpaceSize(),
         din_n_c_hi_wi_device.mDesc.GetElementSpaceSize(),
-        PassThrough{});
+        window_spatial_lengths,
+        window_strides);

     if(!pool_bwd.IsSupportedArgument(pool_bwd_argument_ptr.get()))
     {
@@ -192,7 +191,7 @@ bool maxpool_bwd_test(bool do_verification,
         OutDataType,
         ComputeDataType,
         IndexDataType,
-        ReduceOpId,
+        ck::ReduceTensorOp::MAX,
         PropagateNan,
         true>;
......
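The Ho/Wo computation above is the standard pooling output-size formula. As a quick sanity check with the shapes used in the example main() below (a hypothetical standalone snippet, not part of the commit):

// Output size check for the 31x31 input, 2x2 window, stride 2,
// left pad 0, right pad 1 used in the example below.
#include <cassert>

int main()
{
    const int Hi = 31, Y = 2, window_stride_h = 2;
    const int in_left_pad_h = 0, in_right_pad_h = 1;
    const int Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
    assert(Ho == 16); // (31 + 0 + 1 - 2) / 2 + 1
    return 0;
}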
@@ -27,24 +27,18 @@ int main()
     bool time_kernel = false;

     // Pool shape
-    constexpr ck::index_t N               = 1;
-    constexpr ck::index_t C               = 1;
-    constexpr ck::index_t Y               = 2;
-    constexpr ck::index_t X               = 2;
-    constexpr ck::index_t Hi              = 31;
-    constexpr ck::index_t Wi              = 31;
-    constexpr ck::index_t window_stride_h = 2;
-    constexpr ck::index_t window_stride_w = 2;
-    constexpr ck::index_t in_left_pad_h   = 0;
-    constexpr ck::index_t in_left_pad_w   = 0;
-    constexpr ck::index_t in_right_pad_h  = 1;
-    constexpr ck::index_t in_right_pad_w  = 1;
-
-    constexpr bool WindowOverlap = Y > window_stride_h || X > window_stride_w;
-    constexpr ck::InMemoryDataOperationEnum MemOp = WindowOverlap
-                                                        ? ck::InMemoryDataOperationEnum::AtomicAdd
-                                                        : ck::InMemoryDataOperationEnum::Set;
-    std::cout << "WindowOverlap = " << WindowOverlap << std::endl;
+    ck::index_t N               = 1;
+    ck::index_t C               = 1;
+    ck::index_t Y               = 2;
+    ck::index_t X               = 2;
+    ck::index_t Hi              = 31;
+    ck::index_t Wi              = 31;
+    ck::index_t window_stride_h = 2;
+    ck::index_t window_stride_w = 2;
+    ck::index_t in_left_pad_h   = 0;
+    ck::index_t in_left_pad_w   = 0;
+    ck::index_t in_right_pad_h  = 1;
+    ck::index_t in_right_pad_w  = 1;

     bool pass = maxpool_bwd_test<InDataType,
                                  OutDataType,
@@ -54,9 +48,7 @@ int main()
                                  DOutDataType,
                                  InLayout,
                                  OutLayout,
-                                 ck::ReduceTensorOp::MAX,
-                                 PropagateNan,
-                                 MemOp>(do_verification,
+                                 PropagateNan>(do_verification,
                                                time_kernel,
                                                N,
                                                C,
......
include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp (new file)
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/utility/reduction_enums.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+// For pooling operators that also produce indices, such as MaxPool, MinPool, etc.
+template <typename DOutDataType, typename IndexDataType, typename DInDataType>
+struct DeviceIndexPoolBwd : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_dout,
+                        const void* p_indices,
+                        void* p_din,
+                        index_t dout_length,
+                        index_t din_length,
+                        std::vector<ck::index_t> window_lengths,
+                        std::vector<ck::index_t> window_strides) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
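For orientation, here is a minimal sketch of how this interface and the implementation below are meant to be driven, mirroring the maxpool_bwd_test changes above. The function name, buffer pointers, element counts, and the float/int32_t instantiation are illustrative assumptions, not part of the commit; p_din is assumed to be zero-filled beforehand, since the scatter only writes the gathered positions.

#include "ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp"

void run_maxpool_bwd(const float* p_dout,      // gradient w.r.t. pooled output (device)
                     const int32_t* p_indices, // indices saved by the forward pass (device)
                     float* p_din,             // gradient w.r.t. input, pre-zeroed (device)
                     ck::index_t dout_length,
                     ck::index_t din_length)
{
    using DeviceMaxPoolBwdInstance =
        ck::tensor_operation::device::DeviceIndexPoolBwdImpl<float, int32_t, float, 4>;

    DeviceMaxPoolBwdInstance pool_bwd;
    auto argument_ptr = pool_bwd.MakeArgumentPointer(p_dout,
                                                     p_indices,
                                                     p_din,
                                                     dout_length,
                                                     din_length,
                                                     {2, 2},  // window lengths (Y, X)
                                                     {2, 2}); // window strides
    if(pool_bwd.IsSupportedArgument(argument_ptr.get()))
    {
        auto invoker_ptr = pool_bwd.MakeInvokerPointer();
        invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
    }
}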
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp (new file)
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <iostream>
+#include <sstream>
+
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
+#include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+// Scatter: din[indices[i]] = dout[i] (or += when pooling windows overlap)
+template <typename DOutDataType,
+          typename IndexDataType,
+          typename DInDataType,
+          ck::index_t InVectorSize>
+struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDataType, DInDataType>
+{
+    static_assert(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>,
+                  "Data type is not supported!");
+
+    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+    // Pad the 1-D length up to a multiple of gridSize * blockSize * InVectorSize,
+    // so every thread runs a whole number of vectorized loop steps.
+    template <typename Desc_M>
+    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
+    {
+        constexpr auto I0 = Number<0>{};
+
+        const auto m            = desc_m.GetLength(I0);
+        const index_t loop_step = gridSize * blockSize * InVectorSize;
+        const auto pad          = math::integer_least_multiple(m, loop_step) - m;
+        const auto desc_m_pad =
+            transform_tensor_descriptor(desc_m,
+                                        make_tuple(make_right_pad_transform(m, pad)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return desc_m_pad;
+    }
+
+    static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize)
+    {
+        const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
+        return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
+    }
+
+    using OutGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1));
+
+    using GridwisePutElementSet = GridwisePutElement_1D<OutGrid1dDesc,
+                                                        DOutDataType,
+                                                        IndexDataType,
+                                                        DInDataType,
+                                                        PassThrough,
+                                                        InMemoryDataOperationEnum::Set,
+                                                        InVectorSize>;
+
+    using GridwisePutElementAtomicAdd = GridwisePutElement_1D<OutGrid1dDesc,
+                                                              DOutDataType,
+                                                              IndexDataType,
+                                                              DInDataType,
+                                                              PassThrough,
+                                                              InMemoryDataOperationEnum::AtomicAdd,
+                                                              InVectorSize>;
+
+    struct Argument : public BaseArgument
+    {
+        Argument(const DOutDataType* p_dout,
+                 const IndexDataType* p_indices,
+                 DInDataType* p_din,
+                 index_t dout_length,
+                 const std::vector<ck::index_t>& window_lengths,
+                 const std::vector<ck::index_t>& window_strides)
+            : p_dout_{p_dout},
+              p_indices_{p_indices},
+              p_din_{p_din},
+              blockSize_{256},
+              gridSize_{104}, // FIXME - Calculate the grid size by number of CU in the future
+              windowOverlap_{false}
+        {
+            dout_grid_desc_ = MakeDescriptor_M(dout_length, gridSize_, blockSize_);
+
+            // Windows overlap when any window length exceeds its stride; in that
+            // case several dout elements can map to the same din index.
+            for(size_t i = 0; i < window_lengths.size(); ++i)
+            {
+                windowOverlap_ |= window_lengths.at(i) > window_strides.at(i);
+            }
+        }
+
+        const DOutDataType* p_dout_;
+        const IndexDataType* p_indices_;
+        DInDataType* p_din_;
+        index_t blockSize_;
+        index_t gridSize_;
+        bool windowOverlap_;
+        OutGrid1dDesc dout_grid_desc_;
+    };
+
+    struct Invoker : public BaseInvoker
+    {
+        // Overlapping windows need AtomicAdd so concurrent scatters to the same
+        // din element accumulate instead of racing; otherwise plain Set suffices.
+        static auto KernelSelector(bool windowOverlap)
+        {
+            if(windowOverlap)
+                return kernel_put_element_1d<GridwisePutElementAtomicAdd,
+                                             OutGrid1dDesc,
+                                             DOutDataType,
+                                             IndexDataType,
+                                             DInDataType,
+                                             PassThrough>;
+            else
+                return kernel_put_element_1d<GridwisePutElementSet,
+                                             OutGrid1dDesc,
+                                             DOutDataType,
+                                             IndexDataType,
+                                             DInDataType,
+                                             PassThrough>;
+        }
+
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            const auto kernel = KernelSelector(arg.windowOverlap_);
+
+            float elapsed_time = launch_and_time_kernel(stream_config,
+                                                        kernel,
+                                                        dim3(arg.gridSize_),
+                                                        dim3(arg.blockSize_),
+                                                        0,
+                                                        arg.dout_grid_desc_,
+                                                        arg.p_dout_,
+                                                        arg.p_indices_,
+                                                        arg.p_din_,
+                                                        PassThrough{});
+            return elapsed_time;
+        }
+
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
+
+        // TODO
+        ignore = pArg;
+        return true;
+    }
+
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_dout,
+                        const void* p_indices,
+                        void* p_din,
+                        index_t dout_length,
+                        index_t,
+                        std::vector<ck::index_t> window_lengths,
+                        std::vector<ck::index_t> window_strides) override
+    {
+        return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
+                                          static_cast<const IndexDataType*>(p_indices),
+                                          static_cast<DInDataType*>(p_din),
+                                          dout_length,
+                                          window_lengths,
+                                          window_strides);
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
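The runtime Set-versus-AtomicAdd selection above hinges on window overlap: when any window length exceeds its stride, neighboring pooling windows share input elements, so several dout values may scatter to the same din index and must be accumulated atomically. A standalone sketch of that predicate (hypothetical helper name, not part of the commit):

#include <cassert>
#include <vector>

// Mirrors the windowOverlap_ computation in Argument's constructor.
bool windows_overlap(const std::vector<int>& window_lengths,
                     const std::vector<int>& window_strides)
{
    bool overlap = false;
    for(size_t i = 0; i < window_lengths.size(); ++i)
        overlap |= window_lengths[i] > window_strides[i];
    return overlap;
}

int main()
{
    assert(!windows_overlap({2, 2}, {2, 2})); // disjoint windows -> Set
    assert(windows_overlap({3, 3}, {2, 2}));  // overlapping -> AtomicAdd
    return 0;
}

This is the same check the example's main() previously performed at compile time via the WindowOverlap constant; moving it into Argument lets pool shapes be chosen at runtime, which is why the MemOp template parameter disappears from the two files below.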
include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp
@@ -96,8 +96,7 @@ struct DevicePutElementImpl
                                                  InDataType,
                                                  IndexDataType,
                                                  OutDataType,
-                                                 ElementwiseOperation,
-                                                 MemOp>;
+                                                 ElementwiseOperation>;

         float elapsed_time = launch_and_time_kernel(stream_config,
                                                     kernel,
......
include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp
@@ -14,8 +14,7 @@ template <typename GridwisePutElementwise1dFunctor,
           typename InDataType,
           typename IndexDataType,
           typename OutDataType,
-          typename ElementwiseOperation,
-          InMemoryDataOperationEnum MemOp>
+          typename ElementwiseOperation>
 __global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc,
                                       const InDataType* __restrict__ p_in_global,
                                       const IndexDataType* __restrict__ p_indices_global,
......