Commit 7905cf77 authored by rocking

Extract pooling reference code

parent 7b833910
@@ -17,111 +17,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool PropagateNan,
bool OutputIndex>
static void pool_host_verify(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
Tensor<IndexDataType>& out_indices,
const std::array<ck::index_t, 2>& window_spatial_lengths,
const std::array<ck::index_t, 2>& window_strides,
const std::array<ck::index_t, 2>& in_left_pads,
const std::array<ck::index_t, 2>& /*in_right_pads*/)
{
const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
auto elementwise_ops =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
auto in_elementwise_op = std::get<0>(elementwise_ops);
auto acc_elementwise_op = std::get<1>(elementwise_ops);
if constexpr(!OutputIndex)
{
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{
ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0];
for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
{
ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
if(hi >= 0 && hi < static_cast<ck::index_t>(in.mDesc.GetLengths()[2]) &&
wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[3]))
{
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
}
}
}
acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal;
};
make_ParallelTensorFunctor(f_nchw,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
}
else
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{
ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0];
for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
{
ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
if(hi >= 0 && hi < static_cast<ck::index_t>(in.mDesc.GetLengths()[2]) &&
wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[3]))
{
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
IndexDataType currIndex = in.GetOffsetFromMultiIndex(n, c, hi, wi);
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
}
}
}
acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal;
out_indices(n, c, ho, wo) = accuIndex;
};
make_ParallelTensorFunctor(f_nchw,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
};
}
#include "ck/library/reference_tensor_operation/cpu/reference_pooling_fwd.hpp"
template <typename InDataType,
typename OutDataType,
@@ -252,19 +148,28 @@ bool pool_test(bool do_verification,
if(do_verification)
{
pool_host_verify<InDataType,
OutDataType,
AccDataType,
IndexDataType,
ReduceOpId,
PropagateNan,
OutputIndex>(in_n_c_hi_wi,
out_n_c_ho_wo_host,
out_indices_n_c_ho_wo_host,
window_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
using ReferencePoolingFwdInstance =
ck::tensor_operation::host::ReferencePoolingFwd<4,
2,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
ReduceOpId,
PropagateNan,
OutputIndex>;
auto ref_pooling = ReferencePoolingFwdInstance{};
auto ref_pooling_invoker = ref_pooling.MakeInvoker();
auto ref_pooling_argument = ref_pooling.MakeArgument(in_n_c_hi_wi,
out_n_c_ho_wo_host,
out_indices_n_c_ho_wo_host,
window_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
ref_pooling_invoker.Run(ref_pooling_argument);
out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());
......
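Both the deleted pool_host_verify above and the extracted reference operator map each output coordinate back into the input with hi = ho * stride + y - pad (and likewise along the width axis), skipping window taps that land in the implicit padding so that padded positions never contribute to the reduction. Below is a minimal standalone sketch of that arithmetic, assuming a max reduction and hypothetical sizes; the real code routes values through ReduceOperation and the elementwise ops instead:

#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    // Hypothetical sizes, chosen for illustration only.
    const int Hi = 4, Wi = 4;              // input spatial lengths
    const int Y = 2, X = 2;                // window lengths
    const int Sh = 2, Sw = 2;              // window strides
    const int Ph = 1, Pw = 1;              // left pads
    const int Ho = (Hi + Ph - Y) / Sh + 1; // right pad omitted for brevity
    const int Wo = (Wi + Pw - X) / Sw + 1;

    std::vector<float> in(Hi * Wi);
    for(int i = 0; i < Hi * Wi; ++i)
        in[i] = static_cast<float>(i);

    for(int ho = 0; ho < Ho; ++ho)
        for(int wo = 0; wo < Wo; ++wo)
        {
            // Identity value of a max reduction, mirroring GetIdentityValue.
            float acc = -std::numeric_limits<float>::infinity();
            for(int y = 0; y < Y; ++y)
                for(int x = 0; x < X; ++x)
                {
                    // Map the output coordinate plus window tap back into the input.
                    const int hi = ho * Sh + y - Ph;
                    const int wi = wo * Sw + x - Pw;
                    // Out-of-range taps are the implicit padding; skip them.
                    if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                        acc = std::max(acc, in[hi * Wi + wi]);
                }
            std::printf("out(%d,%d) = %g\n", ho, wo, acc);
        }
}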
@@ -16,124 +16,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool PropagateNan,
bool OutputIndex>
static void pool3d_host_verify(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
Tensor<IndexDataType>& out_indices,
const std::array<ck::index_t, 3>& window_spatial_lengths,
const std::array<ck::index_t, 3>& window_strides,
const std::array<ck::index_t, 3>& in_left_pads,
const std::array<ck::index_t, 3>& /*in_right_pads*/)
{
const int32_t reduceLength =
window_spatial_lengths[0] * window_spatial_lengths[1] * window_spatial_lengths[2];
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
auto elementwise_ops =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
auto in_elementwise_op = std::get<0>(elementwise_ops);
auto acc_elementwise_op = std::get<1>(elementwise_ops);
if constexpr(!OutputIndex)
{
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(ck::index_t z = 0; z < window_spatial_lengths[0]; ++z)
{
ck::index_t di = do_ * window_strides[0] + z - in_left_pads[0];
for(ck::index_t y = 0; y < window_spatial_lengths[1]; ++y)
{
ck::index_t hi = ho * window_strides[1] + y - in_left_pads[1];
for(ck::index_t x = 0; x < window_spatial_lengths[2]; ++x)
{
ck::index_t wi = wo * window_strides[2] + x - in_left_pads[2];
if(di >= 0 && di < static_cast<ck::index_t>(in.mDesc.GetLengths()[2]) &&
hi >= 0 && hi < static_cast<ck::index_t>(in.mDesc.GetLengths()[3]) &&
wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[4]))
{
AccDataType currVal = static_cast<AccDataType>(in(n, c, di, hi, wi));
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
}
}
}
}
acc_elementwise_op(accuVal, accuVal);
out(n, c, do_, ho, wo) = accuVal;
};
make_ParallelTensorFunctor(f_ncdhw,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3],
out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
}
else
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t z = 0; z < window_spatial_lengths[0]; ++z)
{
ck::index_t di = do_ * window_strides[0] + z - in_left_pads[0];
for(ck::index_t y = 0; y < window_spatial_lengths[1]; ++y)
{
ck::index_t hi = ho * window_strides[1] + y - in_left_pads[1];
for(ck::index_t x = 0; x < window_spatial_lengths[2]; ++x)
{
ck::index_t wi = wo * window_strides[2] + x - in_left_pads[2];
if(di >= 0 && di < static_cast<ck::index_t>(in.mDesc.GetLengths()[2]) &&
hi >= 0 && hi < static_cast<ck::index_t>(in.mDesc.GetLengths()[3]) &&
wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[4]))
{
AccDataType currVal = static_cast<AccDataType>(in(n, c, di, hi, wi));
IndexDataType currIndex = in.GetOffsetFromMultiIndex(n, c, di, hi, wi);
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
}
}
}
}
acc_elementwise_op(accuVal, accuVal);
out(n, c, do_, ho, wo) = accuVal;
out_indices(n, c, do_, ho, wo) = accuIndex;
};
make_ParallelTensorFunctor(f_ncdhw,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3],
out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
};
}
#include "ck/library/reference_tensor_operation/cpu/reference_pooling_fwd.hpp"
template <typename InDataType,
typename OutDataType,
@@ -262,19 +145,28 @@ bool pool3d_test(bool do_verification,
if(do_verification)
{
pool3d_host_verify<InDataType,
OutDataType,
AccDataType,
IndexDataType,
ReduceOpId,
PropagateNan,
OutputIndex>(in_n_c_di_hi_wi,
out_n_c_do_ho_wo_host,
out_indices_n_c_do_ho_wo_host,
window_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
using ReferencePoolingFwdInstance =
ck::tensor_operation::host::ReferencePoolingFwd<5,
3,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
ReduceOpId,
PropagateNan,
OutputIndex>;
auto ref_pooling = ReferencePoolingFwdInstance{};
auto ref_pooling_invoker = ref_pooling.MakeInvoker();
auto ref_pooling_argument = ref_pooling.MakeArgument(in_n_c_di_hi_wi,
out_n_c_do_ho_wo_host,
out_indices_n_c_do_ho_wo_host,
window_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
ref_pooling_invoker.Run(ref_pooling_argument);
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
......
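When OutputIndex is set, both the removed 3D verifier and the reference op record the winning window tap as a flat offset obtained from GetOffsetFromMultiIndex. A minimal sketch of that bookkeeping, assuming a packed row-major NCHW layout and a plain max; the real code folds this into Accumulation::Calculate together with the NaN check:

#include <cstdio>

int main()
{
    const int N = 1, C = 1, Hi = 3, Wi = 3;
    float in[N * C * Hi * Wi] = {0, 5, 1, 2, 9, 3, 4, 6, 7};

    float best      = in[0];
    int   bestIndex = 0;
    for(int hi = 0; hi < Hi; ++hi)
        for(int wi = 0; wi < Wi; ++wi)
        {
            // Row-major flat offset of (n=0, c=0, hi, wi) in a packed tensor,
            // i.e. what GetOffsetFromMultiIndex returns for this layout.
            const int offset = (0 * C + 0) * Hi * Wi + hi * Wi + wi;
            if(in[offset] > best)
            {
                best      = in[offset];
                bestIndex = offset;
            }
        }
    std::printf("max %g at flat offset %d\n", best, bestIndex); // 9 at 4
}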
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <index_t InOutRank,
index_t WindowRank,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool PropagateNan,
bool OutputIndex>
struct ReferencePoolingFwd : public device::BaseOperator
{
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
Tensor<IndexDataType>& out_indices,
const std::array<ck::index_t, WindowRank>& window_spatial_lengths,
const std::array<ck::index_t, WindowRank>& window_strides,
const std::array<ck::index_t, WindowRank>& in_left_pads,
const std::array<ck::index_t, WindowRank>& /*in_right_pads*/)
: in_(in),
out_(out),
out_indices_(out_indices),
window_spatial_lengths_(window_spatial_lengths),
window_strides_(window_strides),
in_left_pads_(in_left_pads),
reduceLength_(1)
{
static_for<0, WindowRank, 1>{}(
[&](auto I) { reduceLength_ *= window_spatial_lengths[I]; });
}
const Tensor<InDataType>& in_;
Tensor<OutDataType>& out_;
Tensor<IndexDataType>& out_indices_;
const std::array<ck::index_t, WindowRank>& window_spatial_lengths_;
const std::array<ck::index_t, WindowRank>& window_strides_;
const std::array<ck::index_t, WindowRank>& in_left_pads_;
int reduceLength_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float RunPooling3dFwd(const Argument& arg)
{
auto elementwise_ops =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
arg.reduceLength_);
auto in_elementwise_op = std::get<0>(elementwise_ops);
auto acc_elementwise_op = std::get<1>(elementwise_ops);
if constexpr(!OutputIndex)
{
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{
ck::index_t wi =
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
AccDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, di, hi, wi));
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
}
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal;
};
make_ParallelTensorFunctor(f_ncdhw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3],
arg.out_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
}
else
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{
ck::index_t wi =
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
AccDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, di, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
}
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal;
arg.out_indices_(n, c, do_, ho, wo) = accuIndex;
};
make_ParallelTensorFunctor(f_ncdhw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3],
arg.out_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
};
return 0;
}
float RunPooling2dFwd(const Argument& arg)
{
auto elementwise_ops =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
arg.reduceLength_);
auto in_elementwise_op = std::get<0>(elementwise_ops);
auto acc_elementwise_op = std::get<1>(elementwise_ops);
if constexpr(!OutputIndex)
{
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
AccDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, hi, wi));
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal;
};
make_ParallelTensorFunctor(f_nchw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
}
else
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
AccDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal;
arg.out_indices_(n, c, ho, wo) = accuIndex;
};
make_ParallelTensorFunctor(f_nchw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
};
return 0;
}
float Run(const Argument& arg)
{
// TODO - support generic pooling
if constexpr(InOutRank == 5 && WindowRank == 3)
return RunPooling3dFwd(arg);
else if constexpr(InOutRank == 4 && WindowRank == 2)
return RunPooling2dFwd(arg);
else
throw std::runtime_error("Only 2D and 3D pooling are supported so far");
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
Tensor<IndexDataType>& out_indices,
const std::array<ck::index_t, WindowRank>& window_spatial_lengths,
const std::array<ck::index_t, WindowRank>& window_strides,
const std::array<ck::index_t, WindowRank>& in_left_pads,
const std::array<ck::index_t, WindowRank>& in_right_pads)
{
return Argument{in,
out,
out_indices,
window_spatial_lengths,
window_strides,
in_left_pads,
in_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferencePoolingFwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
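For reference, a hypothetical driver sketch showing how a test can instantiate the extracted operator for 2D max pooling with index output. The Tensor<> arguments are assumed to be allocated by the caller exactly as in pool_test above; only the instantiation pattern is the point here.

#include <array>
#include <cstdint>

#include "ck/library/reference_tensor_operation/cpu/reference_pooling_fwd.hpp"

void run_reference_max_pool2d(const Tensor<float>& in_n_c_hi_wi,
                              Tensor<float>& out_n_c_ho_wo,
                              Tensor<int32_t>& out_indices_n_c_ho_wo)
{
    using ReferencePoolingFwdInstance =
        ck::tensor_operation::host::ReferencePoolingFwd<4,       // in/out rank (NCHW)
                                                        2,       // window rank (Y, X)
                                                        float,   // InDataType
                                                        float,   // OutDataType
                                                        float,   // AccDataType
                                                        int32_t, // IndexDataType
                                                        ck::ReduceTensorOp::MAX,
                                                        false,   // PropagateNan
                                                        true>;   // OutputIndex

    auto ref_pooling          = ReferencePoolingFwdInstance{};
    auto ref_pooling_invoker  = ref_pooling.MakeInvoker();
    auto ref_pooling_argument = ref_pooling.MakeArgument(in_n_c_hi_wi,
                                                         out_n_c_ho_wo,
                                                         out_indices_n_c_ho_wo,
                                                         {2, 2},  // window lengths
                                                         {2, 2},  // window strides
                                                         {0, 0},  // left pads
                                                         {0, 0}); // right pads (unused)
    ref_pooling_invoker.Run(ref_pooling_argument);
}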