Commit 29448ffd authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

merge from develop and revision for PR #881

parents 9223a5e2 8f84a012
...@@ -125,7 +125,7 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -125,7 +125,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc); arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
...@@ -201,7 +201,7 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -201,7 +201,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc); arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
...@@ -299,7 +299,7 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -299,7 +299,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc); arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
......
...@@ -92,11 +92,11 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -92,11 +92,11 @@ struct ReferenceGemm : public device::BaseOperator
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b); ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
} }
AccDataType v_c; CDataType v_c;
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c); arg.c_m_n_(m, n) = v_c;
}; };
make_ParallelTensorFunctor( make_ParallelTensorFunctor(
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <type_traits>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/**
* \brief Reference implementation for image to column.
*
* Tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam InputLayout Input Layout.
* \tparam InDataType Input Data Type.
* \tparam OutDataType Output Data Type.
*/
/// Reference (host-side) image-to-column operator: flattens each convolution
/// window of the input image into one row of a 2D output matrix.
/// Output shape is [N * Do * Ho * Wo, C * Z * Y * X] (see IsSupportedArgument).
/// Supports 1D/2D/3D spatial inputs; G (group) dimension must be 1.
template <ck::index_t NDimSpatial,
          typename InputLayout,
          typename InDataType,
          typename OutDataType,
          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceImageToColumn : public device::BaseOperator
{
    // Argument: borrows the input/output tensors by reference and copies the
    // convolution geometry vectors (strides, dilations, pads, filter lengths).
    struct Argument : public device::BaseArgument
    {
        public:
        Argument(const Tensor<InDataType>& input,
                 Tensor<OutDataType>& output,
                 std::vector<ck::index_t> filter_spatial_lengths,
                 std::vector<ck::index_t> conv_filter_strides,
                 std::vector<ck::index_t> conv_filter_dilations,
                 std::vector<ck::index_t> input_left_pads,
                 std::vector<ck::index_t> input_right_pads)
            : input_{input},
              output_{output},
              conv_strides_{conv_filter_strides},
              conv_dilations_{conv_filter_dilations},
              in_left_pads_{input_left_pads},
              in_right_pads_{input_right_pads},
              filter_spatial_lengths_{filter_spatial_lengths}
        {
            // Derive Do/Ho/Wo once at construction; Run() and
            // IsSupportedArgument() both read output_spatial_lengths_.
            initOutputSpatialLengths();
        }

        const Tensor<InDataType>& input_;  // [G, N, C, spatial...] descriptor layout
        Tensor<OutDataType>& output_;      // 2D matrix [rows, columns]
        std::vector<index_t> conv_strides_;
        std::vector<index_t> conv_dilations_;
        std::vector<index_t> in_left_pads_;
        std::vector<index_t> in_right_pads_;
        std::vector<index_t> filter_spatial_lengths_;
        std::vector<index_t> output_spatial_lengths_;  // computed, not user-supplied

        private:
        // Standard convolution output-length formula per spatial dim.
        void initOutputSpatialLengths()
        {
            // Spatial lengths start after [G, N, C] in the input descriptor.
            constexpr auto input_offset_to_spatial = 3;

            for(ck::index_t i = 0; i < NDimSpatial; ++i)
            {
                // XEff = (X - 1) * conv_dilation_w + 1;
                // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
                const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
                output_spatial_lengths_.push_back(
                    (input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
                     in_right_pads_[i] - x_eff) /
                        conv_strides_[i] +
                    1);
            }
        }
    };

    // Invoker: performs the actual im2col copy on the host, parallelized over
    // the output rows via make_ParallelTensorFunctor.
    //
    // NOTE(review): positions that fall into padding (wi/hi/di out of range)
    // are skipped, not zero-filled — the output tensor is presumably expected
    // to be pre-initialized by the caller. TODO confirm against callers.
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceImageToColumn::Argument;

        // Returns 0 (elapsed time is not measured by this reference op).
        float Run(const Argument& arg)
        {
            // Input must be [G, N, C, spatial...]; output must be a 2D matrix.
            if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
                 arg.output_.GetNumOfDimension() == 2))
            {
                throw std::runtime_error("wrong! inconsistent dimension");
            }

            const index_t N = arg.input_.GetLengths()[1];
            const index_t C = arg.input_.GetLengths()[2];

            if constexpr(NDimSpatial == 1)
            {
                const index_t Wo = arg.output_spatial_lengths_[0];
                // One lambda invocation per output row (n, wo); each row walks
                // the whole filter window and emits C values per tap.
                auto func = [&](auto n, auto wo) {
                    index_t row    = n * Wo + wo;
                    index_t column = 0;

                    for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
                    {
                        // Signed arithmetic: wi can be negative in the left pad.
                        auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
                                  static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
                        for(index_t c = 0; c < C; ++c)
                        {
                            if(wi >= 0 &&
                               ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
                            {
                                InDataType v_in          = arg.input_(0, n, c, wi);
                                arg.output_(row, column) = ck::type_convert<OutDataType>(v_in);
                            }
                            // Column advances even when the tap is padding.
                            column++;
                        }
                    }
                };

                make_ParallelTensorFunctor(func, N, Wo)(std::thread::hardware_concurrency());

                return 0;
            }
            else if constexpr(NDimSpatial == 2)
            {
                const index_t Ho = arg.output_spatial_lengths_[0];
                const index_t Wo = arg.output_spatial_lengths_[1];
                // One invocation per output row (n, ho, wo).
                auto func = [&](auto n, auto ho, auto wo) {
                    index_t row    = n * Ho * Wo + ho * Wo + wo;
                    index_t column = 0;

                    for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
                    {
                        auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
                                  static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
                        for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
                        {
                            auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
                                      static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
                                      static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
                            for(index_t c = 0; c < C; ++c)
                            {
                                if(hi >= 0 &&
                                   ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
                                   wi >= 0 &&
                                   ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
                                {
                                    InDataType v_in = arg.input_(0, n, c, hi, wi);
                                    arg.output_(row, column) = ck::type_convert<OutDataType>(v_in);
                                }
                                column++;
                            }
                        }
                    }
                };

                make_ParallelTensorFunctor(func, N, Ho, Wo)(std::thread::hardware_concurrency());

                return 0;
            }
            else if constexpr(NDimSpatial == 3)
            {
                const index_t Do = arg.output_spatial_lengths_[0];
                const index_t Ho = arg.output_spatial_lengths_[1];
                const index_t Wo = arg.output_spatial_lengths_[2];
                // One invocation per output row (n, d_o, ho, wo).
                // ("d_o" avoids the `do` keyword.)
                auto func = [&](auto n, auto d_o, auto ho, auto wo) {
                    index_t row    = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
                    index_t column = 0;

                    for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
                    {
                        auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
                                  static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
                        for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
                        {
                            auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
                                      static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
                                      static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
                            for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
                            {
                                auto wi =
                                    static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
                                    static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
                                    static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
                                for(index_t c = 0; c < C; ++c)
                                {
                                    if(di >= 0 &&
                                       ck::type_convert<std::size_t>(di) <
                                           arg.input_.GetLengths()[3] &&
                                       hi >= 0 &&
                                       ck::type_convert<std::size_t>(hi) <
                                           arg.input_.GetLengths()[4] &&
                                       wi >= 0 &&
                                       ck::type_convert<std::size_t>(wi) <
                                           arg.input_.GetLengths()[5])
                                    {
                                        InDataType v_in = arg.input_(0, n, c, di, hi, wi);
                                        arg.output_(row, column) =
                                            ck::type_convert<OutDataType>(v_in);
                                    }
                                    column++;
                                }
                            }
                        }
                    }
                };

                make_ParallelTensorFunctor(func, N, Do, Ho, Wo)(
                    std::thread::hardware_concurrency());

                return 0;
            }
        }

        // Type-erased entry point required by device::BaseInvoker.
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /*stream_config*/ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    // Compile-time layout/dimension check: only channel-last GNWC/GNHWC/GNDHWC
    // memory layouts with 1-3 spatial dims are supported.
    static constexpr bool IsValidCompilationParameter()
    {
        using namespace tensor_layout::convolution;

        if constexpr(!(std::is_same_v<InputLayout, GNWC> || std::is_same_v<InputLayout, GNHWC> ||
                       std::is_same_v<InputLayout, GNDHWC>))
        {
            return false;
        }
        if constexpr(!(NDimSpatial >= 1 && NDimSpatial <= 3))
        {
            return false;
        }
        return true;
    }

    // Runtime check: output must be exactly [N*Do*Ho*Wo, C*Z*Y*X] and G == 1.
    bool IsSupportedArgument(const Argument& arg)
    {
        const ck::index_t G = arg.input_.GetLengths()[0];
        const ck::index_t N = arg.input_.GetLengths()[1];
        const ck::index_t C = arg.input_.GetLengths()[2];

        const index_t NDoHoWo =
            N * ck::accumulate_n<index_t>(
                    arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
        const index_t CZYX =
            C * ck::accumulate_n<index_t>(
                    arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());

        if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) &&
             arg.output_.GetLengths()[1] == static_cast<std::size_t>(CZYX)))
        {
            return false;
        }

        if(G != 1)
        {
            return false;
        }

        return true;
    }

    // Type-erased overload required by device::BaseOperator.
    bool IsSupportedArgument(const device::BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    // Convenience factory mirroring the Argument constructor.
    static auto MakeArgument(const Tensor<InDataType>& input,
                             Tensor<OutDataType>& output,
                             std::vector<ck::index_t> filter_spatial_lengths,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads)
    {
        return Argument{input,
                        output,
                        filter_spatial_lengths,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceImageToColumn"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
using namespace std;
/// Reference (host-side) backward max-pooling: scatter-adds each output
/// gradient into the input-gradient buffer at the flat offset recorded in the
/// forward pass's index tensor. Out-of-range indices (e.g. -1 for "no source")
/// are silently skipped.
///
/// \tparam DOutDataType   Output-gradient element type.
/// \tparam IndexDataType  Flat-offset index type recorded by the forward pass.
/// \tparam ComputeDataType Accumulation type.
/// \tparam DInDataType    Input-gradient element type.
/// \tparam ElementwiseOperation Elementwise op (currently stored but unused).
template <typename DOutDataType,
          typename IndexDataType,
          typename ComputeDataType, // renamed from misspelled "ConputeDataType"
          typename DInDataType,
          typename ElementwiseOperation>
struct ReferenceMaxPoolBwd : public device::BaseOperator
{
    // Argument: borrows dout/indices/din tensors; copies the elementwise op.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<DOutDataType>& dout,
                 const Tensor<IndexDataType>& indices,
                 Tensor<DInDataType>& din,
                 ElementwiseOperation elementwise_op)
            : dout_(dout), indices_(indices), din_(din), elementwise_op_(elementwise_op)
        {
        }

        const Tensor<DOutDataType>& dout_;
        const Tensor<IndexDataType>& indices_;
        Tensor<DInDataType>& din_;
        // NOTE(review): elementwise_op_ is stored but never applied in Run();
        // presumably reserved for future use — confirm before relying on it.
        ElementwiseOperation elementwise_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        // Returns 0 (elapsed time is not measured by this reference op).
        float Run(const Argument& arg)
        {
            // Use a wide signed type for lengths/indices: GetElementSpaceSize()
            // may exceed INT_MAX, and recorded indices can be negative.
            const auto din_length =
                static_cast<ck::long_index_t>(arg.din_.GetElementSpaceSize());
            const auto dout_length =
                static_cast<ck::long_index_t>(arg.dout_.GetElementSpaceSize());

            // Zero-initialized accumulation buffer; gradients from multiple
            // output positions mapping to the same input offset are summed.
            std::vector<ComputeDataType> buf(din_length, 0);

            for(ck::long_index_t i = 0; i < dout_length; ++i)
            {
                const auto index = static_cast<ck::long_index_t>(arg.indices_.mData[i]);
                if(index >= 0 && index < din_length)
                {
                    if constexpr(std::is_same_v<ComputeDataType, bhalf_t>)
                    {
                        // bhalf_t has no operator+=; accumulate in float.
                        float buf_val = ck::type_convert<float>(buf[index]);
                        buf_val += ck::type_convert<float>(arg.dout_.mData[i]);
                        buf[index] = ck::type_convert<ComputeDataType>(buf_val);
                    }
                    else
                    {
                        buf[index] += ck::type_convert<ComputeDataType>(arg.dout_.mData[i]);
                    }
                }
            }

            // Write back, converting to the destination element type.
            for(ck::long_index_t i = 0; i < din_length; ++i)
                arg.din_.mData[i] = ck::type_convert<DInDataType>(buf[i]);

            return 0;
        }

        // Type-erased entry point required by device::BaseInvoker.
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    // Convenience factory mirroring the Argument constructor.
    static auto MakeArgument(const Tensor<DOutDataType>& dout,
                             const Tensor<IndexDataType>& indices,
                             Tensor<DInDataType>& din,
                             ElementwiseOperation elementwise_op)
    {
        return Argument{dout, indices, din, elementwise_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceMaxPoolBwd"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
...@@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices, Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths, const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides, const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& in_left_pads, const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& /*in_right_pads*/) const std::vector<ck::index_t>& /*in_right_pads*/)
: in_(in), : in_(in),
...@@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
out_indices_(out_indices), out_indices_(out_indices),
window_spatial_lengths_(window_spatial_lengths), window_spatial_lengths_(window_spatial_lengths),
window_strides_(window_strides), window_strides_(window_strides),
window_dilations_(window_dilations),
in_left_pads_(in_left_pads), in_left_pads_(in_left_pads),
reduceLength_(1) reduceLength_(1)
{ {
...@@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices_; Tensor<IndexDataType>& out_indices_;
const std::vector<ck::index_t>& window_spatial_lengths_; const std::vector<ck::index_t>& window_spatial_lengths_;
const std::vector<ck::index_t>& window_strides_; const std::vector<ck::index_t>& window_strides_;
const std::vector<ck::index_t>& window_dilations_;
const std::vector<ck::index_t>& in_left_pads_; const std::vector<ck::index_t>& in_left_pads_;
int reduceLength_; int reduceLength_;
}; };
...@@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z) for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{ {
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0]; ck::index_t di = do_ * arg.window_strides_[0] +
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y) for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{ {
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1]; ck::index_t hi = ho * arg.window_strides_[1] +
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x) for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{ {
ck::index_t wi = ck::index_t wi = wo * arg.window_strides_[2] +
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2]; x * arg.window_dilations_[2] -
arg.in_left_pads_[2];
if(di >= 0 && if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) && di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 && hi >= 0 &&
...@@ -100,8 +106,8 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -100,8 +106,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 && wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{ {
ComputeDataType currVal = ComputeDataType currVal = ck::type_convert<ComputeDataType>(
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi)); arg.in_(n, c, di, hi, wi));
in_elementwise_op(currVal, currVal); in_elementwise_op(currVal, currVal);
...@@ -112,7 +118,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -112,7 +118,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
} }
acc_elementwise_op(accuVal, accuVal); acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal; arg.out_(n, c, do_, ho, wo) = ck::type_convert<OutDataType>(accuVal);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
...@@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z) for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{ {
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0]; ck::index_t di = do_ * arg.window_strides_[0] +
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y) for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{ {
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1]; ck::index_t hi = ho * arg.window_strides_[1] +
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x) for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{ {
ck::index_t wi = ck::index_t wi = wo * arg.window_strides_[2] +
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2]; x * arg.window_dilations_[2] -
arg.in_left_pads_[2];
if(di >= 0 && if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) && di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 && hi >= 0 &&
...@@ -151,8 +160,8 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -151,8 +160,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 && wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{ {
ComputeDataType currVal = ComputeDataType currVal = ck::type_convert<ComputeDataType>(
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi)); arg.in_(n, c, di, hi, wi));
IndexDataType currIndex = IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi); arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);
...@@ -166,7 +175,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -166,7 +175,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
acc_elementwise_op(accuVal, accuVal); acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal; arg.out_(n, c, do_, ho, wo) = ck::type_convert<OutDataType>(accuVal);
arg.out_indices_(n, c, do_, ho, wo) = accuIndex; arg.out_indices_(n, c, do_, ho, wo) = accuIndex;
}; };
...@@ -202,17 +211,19 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -202,17 +211,19 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y) for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{ {
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0]; ck::index_t hi = ho * arg.window_strides_[0] +
y * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x) for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{ {
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1]; ck::index_t wi = wo * arg.window_strides_[1] +
x * arg.window_dilations_[1] - arg.in_left_pads_[1];
if(hi >= 0 && if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) && hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 && wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{ {
ComputeDataType currVal = ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi)); ck::type_convert<ComputeDataType>(arg.in_(n, c, hi, wi));
in_elementwise_op(currVal, currVal); in_elementwise_op(currVal, currVal);
...@@ -222,7 +233,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -222,7 +233,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
} }
acc_elementwise_op(accuVal, accuVal); acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal; arg.out_(n, c, ho, wo) = ck::type_convert<OutDataType>(accuVal);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
...@@ -245,17 +256,19 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -245,17 +256,19 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y) for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{ {
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0]; ck::index_t hi = ho * arg.window_strides_[0] +
y * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x) for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{ {
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1]; ck::index_t wi = wo * arg.window_strides_[1] +
x * arg.window_dilations_[1] - arg.in_left_pads_[1];
if(hi >= 0 && if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) && hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 && wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{ {
ComputeDataType currVal = ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi)); ck::type_convert<ComputeDataType>(arg.in_(n, c, hi, wi));
IndexDataType currIndex = IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi); arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);
...@@ -268,7 +281,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -268,7 +281,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
} }
acc_elementwise_op(accuVal, accuVal); acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal; arg.out_(n, c, ho, wo) = ck::type_convert<OutDataType>(accuVal);
arg.out_indices_(n, c, ho, wo) = accuIndex; arg.out_indices_(n, c, ho, wo) = accuIndex;
}; };
...@@ -308,6 +321,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -308,6 +321,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices, Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths, const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides, const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& in_left_pads, const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& in_right_pads) const std::vector<ck::index_t>& in_right_pads)
{ {
...@@ -316,6 +330,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -316,6 +330,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
out_indices, out_indices,
window_spatial_lengths, window_spatial_lengths,
window_strides, window_strides,
window_dilations,
in_left_pads, in_left_pads,
in_right_pads}; in_right_pads};
} }
......
...@@ -17,6 +17,7 @@ namespace instance { ...@@ -17,6 +17,7 @@ namespace instance {
using F64 = double; using F64 = double;
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using F8 = ck::f8_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
...@@ -30,6 +31,9 @@ using F64_Tuple = ck::Tuple<F64>; ...@@ -30,6 +31,9 @@ using F64_Tuple = ck::Tuple<F64>;
using F32_Tuple = ck::Tuple<F32>; using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>; using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>; using I32_F32_Tuple = ck::Tuple<I32, F32>;
using I8_Tuple = ck::Tuple<I8>;
using F32_F32_Tuple = ck::Tuple<F32, F32>;
// GEMM layout // GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
...@@ -94,9 +98,11 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; ...@@ -94,9 +98,11 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu; using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu; using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish; using Swish = ck::tensor_operation::element_wise::Swish;
using Add = ck::tensor_operation::element_wise::Add;
template <typename Activation> template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>; using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment