Commit acd980fc authored by rocking

Add pool bwd device op

parent afe8dae6

@@ -10,7 +10,7 @@
 #include "ck/utility/reduction_functions_accumulate.hpp"
 #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/utility/check_err.hpp"
@@ -29,9 +29,7 @@ template <typename InDataType,
           typename DOutDataType,
           typename InLayout,
           typename OutLayout,
-          ck::ReduceTensorOp ReduceOpId,
-          bool PropagateNan,
-          ck::InMemoryDataOperationEnum Memop>
+          bool PropagateNan>
 bool maxpool_bwd_test(bool do_verification,
                       bool time_kernel,
                       ck::index_t N,
@@ -55,7 +53,7 @@ bool maxpool_bwd_test(bool do_verification,
         OutDataType,     // OutDataType
         IndexDataType,   // IndexDataType
         ComputeDataType, // ComputeDataType
-        ReduceOpId,
+        ck::ReduceTensorOp::MAX,
         true, // OutputIndex
         64,   // BlockSize
         64,   // ReduceMThreadClusterSize
@@ -65,7 +63,7 @@ bool maxpool_bwd_test(bool do_verification,
         1>; // InSrcOutDstVectorSize

     using DeviceMaxPoolBwdInstance = ck::tensor_operation::device::
-        DevicePutElementImpl<DOutDataType, IndexDataType, DInDataType, PassThrough, Memop, 4>;
+        DeviceIndexPoolBwdImpl<DOutDataType, IndexDataType, DInDataType, 4>;

     const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
     const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
@@ -140,7 +138,7 @@ bool maxpool_bwd_test(bool do_verification,
         static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
         static_cast<IndexDataType*>(indices_device_buf.GetDeviceBuffer()),
         {N, C, Hi, Wi},
-        {Y, X},
+        window_spatial_lengths,
         {N, C, Ho, Wo},
         {C * Hi * Wi, 1, Wi * C, C},
         {C * Ho * Wo, 1, Wo * C, C},
@@ -167,7 +165,8 @@ bool maxpool_bwd_test(bool do_verification,
         static_cast<DInDataType*>(din_device_buf.GetDeviceBuffer()),
         dout_n_c_ho_wo.mDesc.GetElementSpaceSize(),
         din_n_c_hi_wi_device.mDesc.GetElementSpaceSize(),
-        PassThrough{});
+        window_spatial_lengths,
+        window_strides);

     if(!pool_bwd.IsSupportedArgument(pool_bwd_argument_ptr.get()))
     {
@@ -192,7 +191,7 @@ bool maxpool_bwd_test(bool do_verification,
            OutDataType,
            ComputeDataType,
            IndexDataType,
-           ReduceOpId,
+           ck::ReduceTensorOp::MAX,
            PropagateNan,
            true>;

@@ -27,24 +27,18 @@ int main()
     bool time_kernel = false;

     // Pool shape
-    constexpr ck::index_t N               = 1;
-    constexpr ck::index_t C               = 1;
-    constexpr ck::index_t Y               = 2;
-    constexpr ck::index_t X               = 2;
-    constexpr ck::index_t Hi              = 31;
-    constexpr ck::index_t Wi              = 31;
-    constexpr ck::index_t window_stride_h = 2;
-    constexpr ck::index_t window_stride_w = 2;
-    constexpr ck::index_t in_left_pad_h   = 0;
-    constexpr ck::index_t in_left_pad_w   = 0;
-    constexpr ck::index_t in_right_pad_h  = 1;
-    constexpr ck::index_t in_right_pad_w  = 1;
-
-    constexpr bool WindowOverlap = Y > window_stride_h || X > window_stride_w;
-    constexpr ck::InMemoryDataOperationEnum MemOp = WindowOverlap
-                                                        ? ck::InMemoryDataOperationEnum::AtomicAdd
-                                                        : ck::InMemoryDataOperationEnum::Set;
-    std::cout << "WindowOverlap = " << WindowOverlap << std::endl;
+    ck::index_t N               = 1;
+    ck::index_t C               = 1;
+    ck::index_t Y               = 2;
+    ck::index_t X               = 2;
+    ck::index_t Hi              = 31;
+    ck::index_t Wi              = 31;
+    ck::index_t window_stride_h = 2;
+    ck::index_t window_stride_w = 2;
+    ck::index_t in_left_pad_h   = 0;
+    ck::index_t in_left_pad_w   = 0;
+    ck::index_t in_right_pad_h  = 1;
+    ck::index_t in_right_pad_w  = 1;

     bool pass = maxpool_bwd_test<InDataType,
                                  OutDataType,
@@ -54,22 +48,20 @@ int main()
                                 DOutDataType,
                                 InLayout,
                                 OutLayout,
-                                ck::ReduceTensorOp::MAX,
-                                PropagateNan,
-                                MemOp>(do_verification,
-                                       time_kernel,
-                                       N,
-                                       C,
-                                       Y,
-                                       X,
-                                       Hi,
-                                       Wi,
-                                       window_stride_h,
-                                       window_stride_w,
-                                       in_left_pad_h,
-                                       in_left_pad_w,
-                                       in_right_pad_h,
-                                       in_right_pad_w);
+                                PropagateNan>(do_verification,
+                                              time_kernel,
+                                              N,
+                                              C,
+                                              Y,
+                                              X,
+                                              Hi,
+                                              Wi,
+                                              window_stride_h,
+                                              window_stride_w,
+                                              in_left_pad_h,
+                                              in_left_pad_w,
+                                              in_right_pad_h,
+                                              in_right_pad_w);

     return (pass ? 0 : 1);
 }

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/utility/reduction_enums.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// For pooling operations that produce indices, such as MaxPool, MinPool, etc.
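//
// A minimal usage sketch (mirrors the maxpool backward example updated in this
// commit; the data types and the vector size 4 are placeholders taken from that
// example, not requirements of the interface):
//
//   using PoolBwd = ck::tensor_operation::device::
//       DeviceIndexPoolBwdImpl<DOutDataType, IndexDataType, DInDataType, 4>;
//
//   PoolBwd pool_bwd;
//   auto arg = pool_bwd.MakeArgumentPointer(p_dout,
//                                           p_indices,
//                                           p_din,
//                                           dout_length,
//                                           din_length,
//                                           window_lengths,
//                                           window_strides);
//   if(pool_bwd.IsSupportedArgument(arg.get()))
//       pool_bwd.MakeInvokerPointer()->Run(arg.get());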
template <typename DOutDataType, typename IndexDataType, typename DInDataType>
struct DeviceIndexPoolBwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_dout,
const void* p_indices,
void* p_din,
index_t dout_length,
index_t din_length,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Scatter the output gradient back through the saved indices: din[indices[i]] = dout[i]
// (accumulated with atomic add when pooling windows overlap)
template <typename DOutDataType,
typename IndexDataType,
typename DInDataType,
ck::index_t InVectorSize>
struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDataType, DInDataType>
{
static_assert(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>,
"Data type is not supported!");
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
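// Right-pad the 1-D descriptor so its length becomes a multiple of
// gridSize * blockSize * InVectorSize; the gridwise kernel's vectorized
// grid-stride loop then has no remainder to special-case.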
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
{
constexpr auto I0 = Number<0>{};
const auto m = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * InVectorSize;
const auto pad = math::integer_least_multiple(m, loop_step) - m;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(m, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m_pad;
}
static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize)
{
const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
}
using OutGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1));
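// Two scatter variants: Set suffices when pooling windows do not overlap, since
// each input element then receives the gradient of at most one output element;
// AtomicAdd is needed when windows overlap and several output gradients land on
// the same input index.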
using GridwisePutElementSet = GridwisePutElement_1D<OutGrid1dDesc,
DOutDataType,
IndexDataType,
DInDataType,
PassThrough,
InMemoryDataOperationEnum::Set,
InVectorSize>;
using GridwisePutElementAtomicAdd = GridwisePutElement_1D<OutGrid1dDesc,
DOutDataType,
IndexDataType,
DInDataType,
PassThrough,
InMemoryDataOperationEnum::AtomicAdd,
InVectorSize>;
struct Argument : public BaseArgument
{
Argument(const DOutDataType* p_dout,
const IndexDataType* p_indices,
DInDataType* p_din,
index_t dout_length,
const std::vector<ck::index_t>& window_lengths,
const std::vector<ck::index_t>& window_strides)
: p_dout_{p_dout},
p_indices_{p_indices},
p_din_{p_din},
blockSize_{256},
gridSize_{104}, // FIXME - Calculate the grid size by number of CU in the future
windowOverlap_{false}
{
dout_grid_desc_ = MakeDescriptor_M(dout_length, gridSize_, blockSize_);
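// Neighbouring windows overlap along a dimension whenever the window length
// exceeds the stride; any overlap forces the atomic-add scatter path.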
for(size_t i = 0; i < window_lengths.size(); ++i)
{
windowOverlap_ |= window_lengths.at(i) > window_strides.at(i);
}
}
const DOutDataType* p_dout_;
const IndexDataType* p_indices_;
DInDataType* p_din_;
index_t blockSize_;
index_t gridSize_;
bool windowOverlap_;
OutGrid1dDesc dout_grid_desc_;
};
struct Invoker : public BaseInvoker
{
constexpr auto KernelSelector(bool windowOverlap)
{
if(windowOverlap)
return kernel_put_element_1d<GridwisePutElementAtomicAdd,
OutGrid1dDesc,
DOutDataType,
IndexDataType,
DInDataType,
PassThrough>;
else
return kernel_put_element_1d<GridwisePutElementSet,
OutGrid1dDesc,
DOutDataType,
IndexDataType,
DInDataType,
PassThrough>;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = KernelSelector(arg.windowOverlap_);
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.dout_grid_desc_,
arg.p_dout_,
arg.p_indices_,
arg.p_din_,
PassThrough{});
return elapsed_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
// TODO
ignore = pArg;
return true;
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_dout,
const void* p_indices,
void* p_din,
index_t dout_length,
index_t, // din_length (unused by this implementation)
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides) override
{
return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
static_cast<const IndexDataType*>(p_indices),
static_cast<DInDataType*>(p_din),
dout_length,
window_lengths,
window_strides);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

@@ -96,8 +96,7 @@ struct DevicePutElementImpl
                 InDataType,
                 IndexDataType,
                 OutDataType,
-                ElementwiseOperation,
-                MemOp>;
+                ElementwiseOperation>;

         float elapsed_time = launch_and_time_kernel(stream_config,
                                                     kernel,

@@ -14,8 +14,7 @@ template <typename GridwisePutElementwise1dFunctor,
           typename InDataType,
           typename IndexDataType,
           typename OutDataType,
-          typename ElementwiseOperation,
-          InMemoryDataOperationEnum MemOp>
+          typename ElementwiseOperation>
 __global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc,
                                       const InDataType* __restrict__ p_in_global,
                                       const IndexDataType* __restrict__ p_indices_global,