Commit a3b4c5cb authored by wangshaojie6

merge develop branch and add gridwise pipeline v3

parents 48918ab9 1677cf70
@@ -84,7 +84,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
            return 0;
        }

-        float Run(const device::BaseArgument* p_arg, int) override
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// FIXME: support arbitrary elementwise operation for A/B/C
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct ReferenceCGemm : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k_real,
const Tensor<ADataType>& a_m_k_imag,
const Tensor<BDataType>& b_k_n_real,
const Tensor<BDataType>& b_k_n_imag,
Tensor<CDataType>& c_m_n_real,
Tensor<CDataType>& c_m_n_imag,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_real_{a_m_k_real},
a_m_k_imag_{a_m_k_imag},
b_k_n_real_{b_k_n_real},
b_k_n_imag_{b_k_n_imag},
c_m_n_real_{c_m_n_real},
c_m_n_imag_{c_m_n_imag},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_real_;
const Tensor<ADataType>& a_m_k_imag_;
const Tensor<BDataType>& b_k_n_real_;
const Tensor<BDataType>& b_k_n_imag_;
Tensor<CDataType>& c_m_n_real_;
Tensor<CDataType>& c_m_n_imag_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceCGemm::Argument;
float Run(const Argument& arg)
{
const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1];
if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1])
{
throw std::runtime_error("wrong! Incompatible real and imag sizes in CGEMM");
}
auto f_mk_kn_mn_real = [&](auto m, auto n) {
float v_c_real = 0;
for(std::size_t k = 0; k < K; ++k)
{
float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));
v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag;
}
arg.c_m_n_real_(m, n) = v_c_real;
};
auto f_mk_kn_mn_imag = [&](auto m, auto n) {
float v_c_imag = 0;
for(std::size_t k = 0; k < K; ++k)
{
float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));
v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real;
}
arg.c_m_n_imag_(m, n) = v_c_imag;
};
make_ParallelTensorFunctor(f_mk_kn_mn_real,
arg.c_m_n_real_.mDesc.GetLengths()[0],
arg.c_m_n_real_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_mk_kn_mn_imag,
arg.c_m_n_imag_.mDesc.GetLengths()[0],
arg.c_m_n_imag_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k_real,
const Tensor<ADataType>& a_m_k_imag,
const Tensor<BDataType>& b_k_n_real,
const Tensor<BDataType>& b_k_n_imag,
Tensor<CDataType>& c_m_n_real,
Tensor<CDataType>& c_m_n_imag,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_m_k_real,
a_m_k_imag,
b_k_n_real,
b_k_n_imag,
c_m_n_real,
c_m_n_imag,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceCGemm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
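Note: the ReferenceCGemm header above computes a complex GEMM as two real-valued passes over split real/imaginary tensors (C_real = A_real*B_real - A_imag*B_imag, C_imag = A_real*B_imag + A_imag*B_real). Below is a minimal standalone sketch of that decomposition, using plain std::vector in place of ck::Tensor; the function name and the row-major layout are invented for illustration only.

// Illustrative only: mirrors the real/imaginary split used by ReferenceCGemm,
// with plain std::vector instead of ck::Tensor (row-major M x K and K x N).
#include <cstddef>
#include <vector>

void cgemm_split_reference(std::size_t M, std::size_t N, std::size_t K,
                           const std::vector<float>& a_re, const std::vector<float>& a_im,
                           const std::vector<float>& b_re, const std::vector<float>& b_im,
                           std::vector<float>& c_re, std::vector<float>& c_im)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc_re = 0.f;
            float acc_im = 0.f;
            for(std::size_t k = 0; k < K; ++k)
            {
                // (a_re + i*a_im) * (b_re + i*b_im)
                acc_re += a_re[m * K + k] * b_re[k * N + n] - a_im[m * K + k] * b_im[k * N + n];
                acc_im += a_re[m * K + k] * b_im[k * N + n] + a_im[m * K + k] * b_re[k * N + n];
            }
            c_re[m * N + n] = acc_re;
            c_im[m * N + n] = acc_im;
        }
}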
-#ifndef REFERENCE_CONV_WRW_HPP
-#define REFERENCE_CONV_WRW_HPP
+#pragma once

 #include <iostream>
 #include <sstream>
@@ -16,7 +15,9 @@ template <typename InDataType,
           typename OutDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
-          typename OutElementwiseOperation>
+          typename OutElementwiseOperation,
+          ck::index_t NumDimSpatial = 2,
+          typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvBwdWeight : public device::BaseOperator
 {
     // Argument
@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op)
-            : in_n_c_hi_wi_{in_n_c_hi_wi},
-              wei_k_c_y_x_{wei_k_c_y_x},
-              out_n_k_ho_wo_{out_n_k_ho_wo},
+            : input_{in_n_c_hi_wi},
+              weight_{wei_k_c_y_x},
+              output_{out_n_k_ho_wo},
              conv_strides_{conv_filter_strides},
              conv_dilations_{conv_filter_dilations},
              in_left_pads_{input_left_pads},
@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
        {
        }

-        const Tensor<InDataType>& in_n_c_hi_wi_;
-        Tensor<WeiDataType>& wei_k_c_y_x_;
-        const Tensor<OutDataType>& out_n_k_ho_wo_;
+        const Tensor<InDataType>& input_;
+        Tensor<WeiDataType>& weight_;
+        const Tensor<OutDataType>& output_;
        std::vector<index_t> conv_strides_;
        std::vector<index_t> conv_dilations_;
@@ -66,55 +67,184 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
        float Run(const Argument& arg)
        {
-            constexpr auto I0 = Number<0>{};
-            constexpr auto I1 = Number<1>{};
-
-            auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
-                float v_acc = 0;
-
-                for(int n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n)
-                {
-                    for(int ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho)
-                    {
-                        int hi = ho * arg.conv_strides_[I0] + y * arg.conv_dilations_[I0] -
-                                 arg.in_left_pads_[I0];
-                        for(int wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo)
-                        {
-                            int wi = wo * arg.conv_strides_[I1] + x * arg.conv_dilations_[I1] -
-                                     arg.in_left_pads_[I1];
-                            if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
-                               wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
-                            {
-                                float v_out;
-                                float v_in;
-
-                                arg.out_element_op_(
-                                    v_out,
-                                    ck::type_convert<float>(arg.out_n_k_ho_wo_(n, k, ho, wo)));
-                                arg.in_element_op_(
-                                    v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
-
-                                v_acc += v_out * v_in;
-                            }
-                        }
-                    }
-                }
-
-                float v_wei;
-
-                arg.wei_element_op_(v_wei, v_acc);
-
-                arg.wei_k_c_y_x_(k, c, y, x) = ck::type_convert<OutDataType>(v_wei);
-            };
-
-            make_ParallelTensorFunctor(f_kcyx,
-                                       arg.wei_k_c_y_x_.mDesc.GetLengths()[0],
-                                       arg.wei_k_c_y_x_.mDesc.GetLengths()[1],
-                                       arg.wei_k_c_y_x_.mDesc.GetLengths()[2],
-                                       arg.wei_k_c_y_x_.mDesc.GetLengths()[3])(
-                std::thread::hardware_concurrency());
-
-            return 0;
+            if constexpr(NumDimSpatial == 1)
+            {
+                constexpr auto I0 = Number<0>{};
+
+                auto f_kcx = [&](auto k, auto c, auto x) {
+                    float v_acc = 0;
+
+                    for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
+                    {
+                        for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo)
+                        {
+                            auto wi =
+                                ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) +
+                                ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) -
+                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
+                            if(wi >= 0 &&
+                               ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
+                            {
+                                float v_out;
+                                float v_in;
+
+                                arg.out_element_op_(v_out,
+                                                    ck::type_convert<float>(arg.output_(n, k, wo)));
+                                arg.in_element_op_(v_in,
+                                                   ck::type_convert<float>(arg.input_(n, c, wi)));
+
+                                v_acc += v_out * v_in;
+                            }
+                        }
+                    }
+
+                    float v_wei;
+
+                    arg.wei_element_op_(v_wei, v_acc);
+
+                    arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei);
+                };
+
+                make_ParallelTensorFunctor(f_kcx,
+                                           arg.weight_.mDesc.GetLengths()[0],
+                                           arg.weight_.mDesc.GetLengths()[1],
+                                           arg.weight_.mDesc.GetLengths()[2])(
+                    std::thread::hardware_concurrency());
+
+                return 0;
+            }
+            else if constexpr(NumDimSpatial == 2)
+            {
+                constexpr auto I0 = Number<0>{};
+                constexpr auto I1 = Number<1>{};
+
+                auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
+                    float v_acc = 0;
+
+                    for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
+                    {
+                        for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho)
+                        {
+                            auto hi =
+                                ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
+                                ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
+                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
+                            for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo)
+                            {
+                                auto wi =
+                                    ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
+                                    ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I1]) -
+                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
+                                if(hi >= 0 &&
+                                   ck::type_convert<std::size_t>(hi) <
+                                       arg.input_.mDesc.GetLengths()[2] &&
+                                   wi >= 0 &&
+                                   ck::type_convert<std::size_t>(wi) <
+                                       arg.input_.mDesc.GetLengths()[3])
+                                {
+                                    float v_out;
+                                    float v_in;
+
+                                    arg.out_element_op_(
+                                        v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo)));
+                                    arg.in_element_op_(
+                                        v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
+
+                                    v_acc += v_out * v_in;
+                                }
+                            }
+                        }
+                    }
+
+                    float v_wei;
+
+                    arg.wei_element_op_(v_wei, v_acc);
+
+                    arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
+                };
+
+                make_ParallelTensorFunctor(f_kcyx,
+                                           arg.weight_.mDesc.GetLengths()[0],
+                                           arg.weight_.mDesc.GetLengths()[1],
+                                           arg.weight_.mDesc.GetLengths()[2],
+                                           arg.weight_.mDesc.GetLengths()[3])(
+                    std::thread::hardware_concurrency());
+
+                return 0;
+            }
+            else if constexpr(NumDimSpatial == 3)
+            {
+                constexpr auto I0 = Number<0>{};
+                constexpr auto I1 = Number<1>{};
+                constexpr auto I2 = Number<2>{};
+
+                auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
+                    float v_acc = 0;
+
+                    for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
+                    {
+                        for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_)
+                        {
+                            auto di =
+                                ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) +
+                                ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) -
+                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
+                            for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
+                            {
+                                auto hi =
+                                    ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) +
+                                    ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I1]) -
+                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
+                                for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4]; ++wo)
+                                {
+                                    auto wi =
+                                        ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I2]) +
+                                        ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I2]) -
+                                        ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
+                                    if(di >= 0 &&
+                                       ck::type_convert<std::size_t>(di) <
+                                           arg.input_.mDesc.GetLengths()[2] &&
+                                       hi >= 0 &&
+                                       ck::type_convert<std::size_t>(hi) <
+                                           arg.input_.mDesc.GetLengths()[3] &&
+                                       wi >= 0 &&
+                                       ck::type_convert<std::size_t>(wi) <
+                                           arg.input_.mDesc.GetLengths()[4])
+                                    {
+                                        float v_out;
+                                        float v_in;
+
+                                        arg.out_element_op_(v_out,
+                                                            ck::type_convert<float>(
+                                                                arg.output_(n, k, do_, ho, wo)));
+                                        arg.in_element_op_(
+                                            v_in,
+                                            ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
+
+                                        v_acc += v_out * v_in;
+                                    }
+                                }
+                            }
+                        }
+                    }
+
+                    float v_wei;
+
+                    arg.wei_element_op_(v_wei, v_acc);
+
+                    arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
+                };
+
+                make_ParallelTensorFunctor(f_kczyx,
+                                           arg.weight_.mDesc.GetLengths()[0],
+                                           arg.weight_.mDesc.GetLengths()[1],
+                                           arg.weight_.mDesc.GetLengths()[2],
+                                           arg.weight_.mDesc.GetLengths()[3],
+                                           arg.weight_.mDesc.GetLengths()[4])(
+                    std::thread::hardware_concurrency());
+
+                return 0;
+            }
        }

-        float Run(const device::BaseArgument* p_arg, int) override
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /*stream_config*/ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
@@ -174,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
 } // namespace host
 } // namespace tensor_operation
 } // namespace ck
-#endif
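Note: the new NumDimSpatial == 1 branch above accumulates output-gradient times input over the batch and output width, using wi = wo*stride + x*dilation - pad. Below is a hedged standalone 1D analogue (raw arrays and invented names, not CK code), kept deliberately close to the reference loop structure.

// Illustrative only: a standalone 1D analogue of the NumDimSpatial == 1 branch above.
#include <cstddef>
#include <vector>

void conv1d_bwd_weight_reference(std::size_t N, std::size_t K, std::size_t C,
                                 std::size_t Wi, std::size_t Wo, std::size_t X,
                                 long stride, long dilation, long left_pad,
                                 const std::vector<float>& in,   // N * C * Wi
                                 const std::vector<float>& out,  // N * K * Wo (output gradient)
                                 std::vector<float>& wei)        // K * C * X  (weight gradient)
{
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t x = 0; x < X; ++x)
            {
                float acc = 0.f;
                for(std::size_t n = 0; n < N; ++n)
                    for(std::size_t wo = 0; wo < Wo; ++wo)
                    {
                        // Same index arithmetic as the reference: signed, so padding can go negative.
                        const long wi = static_cast<long>(wo) * stride +
                                        static_cast<long>(x) * dilation - left_pad;
                        if(wi >= 0 && static_cast<std::size_t>(wi) < Wi)
                            acc += out[(n * K + k) * Wo + wo] * in[(n * C + c) * Wi + wi];
                    }
                wei[(k * C + c) * X + x] = acc;
            }
}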
@@ -78,15 +78,18 @@ struct ReferenceConvBwdData : public device::BaseOperator
                AccDataType v_acc = 0;

-                for(int x = 0; x < X; ++x)
+                for(std::size_t x = 0; x < X; ++x)
                {
-                    int w_tmp = wi + arg.in_left_pads_[0] - x * arg.conv_dilations_[0];
+                    auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
+                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                 ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]);
                    if(w_tmp % arg.conv_strides_[0] == 0)
                    {
-                        int wo = w_tmp / arg.conv_strides_[0];
-                        if(wo >= 0 && wo < Wo)
+                        auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
+                                  ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                        if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                        {
-                            for(int k = 0; k < K; ++k)
+                            for(std::size_t k = 0; k < K; ++k)
                            {
                                AccDataType v_out = 0;
                                AccDataType v_wei = 0;
@@ -128,24 +131,32 @@ struct ReferenceConvBwdData : public device::BaseOperator
                AccDataType v_acc = 0;

-                for(int y = 0; y < Y; ++y)
+                for(std::size_t y = 0; y < Y; ++y)
                {
-                    int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];
+                    auto h_tmp = ck::type_convert<ck::long_index_t>(hi) +
+                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                 ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]);
                    if(h_tmp % arg.conv_strides_[0] == 0)
                    {
-                        int ho = h_tmp / arg.conv_strides_[0];
-                        if(ho >= 0 && ho < Ho)
+                        auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
+                                  ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                        if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
                        {
-                            for(int x = 0; x < X; ++x)
+                            for(std::size_t x = 0; x < X; ++x)
                            {
-                                int w_tmp =
-                                    wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];
+                                auto w_tmp =
+                                    ck::type_convert<ck::long_index_t>(wi) +
+                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
+                                    ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]);
                                if(w_tmp % arg.conv_strides_[1] == 0)
                                {
-                                    int wo = w_tmp / arg.conv_strides_[1];
-                                    if(wo >= 0 && wo < Wo)
+                                    auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
+                                              ck::type_convert<ck::long_index_t>(
+                                                  arg.conv_strides_[1]);
+                                    if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                                    {
-                                        for(int k = 0; k < K; ++k)
+                                        for(std::size_t k = 0; k < K; ++k)
                                        {
                                            AccDataType v_out = 0;
                                            AccDataType v_wei = 0;
@@ -194,33 +205,49 @@ struct ReferenceConvBwdData : public device::BaseOperator
                AccDataType v_acc = 0;

-                for(int z = 0; z < Z; ++z)
+                for(std::size_t z = 0; z < Z; ++z)
                {
-                    int d_tmp = di + arg.in_left_pads_[0] - z * arg.conv_dilations_[0];
+                    auto d_tmp = ck::type_convert<ck::long_index_t>(di) +
+                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                 ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]);
                    if(d_tmp % arg.conv_strides_[0] == 0)
                    {
-                        int do_ = d_tmp / arg.conv_strides_[0];
-                        if(do_ >= 0 && do_ < Do)
+                        auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) /
+                                   ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                        if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
                        {
-                            for(int y = 0; y < Y; ++y)
+                            for(std::size_t y = 0; y < Y; ++y)
                            {
-                                int h_tmp =
-                                    hi + arg.in_left_pads_[1] - y * arg.conv_dilations_[1];
+                                auto h_tmp =
+                                    ck::type_convert<ck::long_index_t>(hi) +
+                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
+                                    ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]);
                                if(h_tmp % arg.conv_strides_[1] == 0)
                                {
-                                    int ho = h_tmp / arg.conv_strides_[1];
-                                    if(ho >= 0 && ho < Ho)
+                                    auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
+                                              ck::type_convert<ck::long_index_t>(
+                                                  arg.conv_strides_[1]);
+                                    if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
                                    {
-                                        for(int x = 0; x < X; ++x)
+                                        for(std::size_t x = 0; x < X; ++x)
                                        {
-                                            int w_tmp = wi + arg.in_left_pads_[2] -
-                                                        x * arg.conv_dilations_[2];
+                                            auto w_tmp =
+                                                ck::type_convert<ck::long_index_t>(wi) +
+                                                ck::type_convert<ck::long_index_t>(
+                                                    arg.in_left_pads_[2]) -
+                                                ck::type_convert<ck::long_index_t>(
+                                                    x * arg.conv_dilations_[2]);
                                            if(w_tmp % arg.conv_strides_[2] == 0)
                                            {
-                                                int wo = w_tmp / arg.conv_strides_[2];
-                                                if(wo >= 0 && wo < Wo)
+                                                auto wo =
+                                                    ck::type_convert<ck::long_index_t>(w_tmp) /
+                                                    ck::type_convert<ck::long_index_t>(
+                                                        arg.conv_strides_[2]);
+                                                if(wo >= 0 &&
+                                                   ck::type_convert<std::size_t>(wo) < Wo)
                                                {
-                                                    for(int k = 0; k < K; ++k)
+                                                    for(std::size_t k = 0; k < K; ++k)
                                                    {
                                                        AccDataType v_out = 0;
                                                        AccDataType v_wei = 0;
@@ -264,7 +291,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
            }
        }

-        float Run(const device::BaseArgument* p_arg, int) override
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
......
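Note: ReferenceConvBwdData inverts the forward index map. For each input position wi and filter tap x it forms w_tmp = wi + pad - x*dilation, and only accumulates when w_tmp is divisible by the stride and the resulting wo lies inside the output. Below is a hedged standalone 1D sketch of that test (raw arrays and invented names, not CK code).

// Illustrative only: a standalone 1D analogue of the stride-inversion test used by
// ReferenceConvBwdData.
#include <cstddef>
#include <vector>

void conv1d_bwd_data_reference(std::size_t N, std::size_t K, std::size_t C,
                               std::size_t Wi, std::size_t Wo, std::size_t X,
                               long stride, long dilation, long left_pad,
                               std::vector<float>& in_grad,     // N * C * Wi
                               const std::vector<float>& wei,   // K * C * X
                               const std::vector<float>& out)   // N * K * Wo (output gradient)
{
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t wi = 0; wi < Wi; ++wi)
            {
                float acc = 0.f;
                for(std::size_t x = 0; x < X; ++x)
                {
                    // Invert wi = wo * stride + x * dilation - pad for wo.
                    const long w_tmp = static_cast<long>(wi) + left_pad -
                                       static_cast<long>(x) * dilation;
                    if(w_tmp % stride != 0)
                        continue;
                    const long wo = w_tmp / stride;
                    if(wo < 0 || static_cast<std::size_t>(wo) >= Wo)
                        continue;
                    for(std::size_t k = 0; k < K; ++k)
                        acc += out[(n * K + k) * Wo + static_cast<std::size_t>(wo)] *
                               wei[(k * C + c) * X + x];
                }
                in_grad[(n * C + c) * Wi + wi] = acc;
            }
}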
-#ifndef REFERENCE_CONV_FWD_HPP
-#define REFERENCE_CONV_FWD_HPP
+#pragma once

 #include <iostream>
 #include <type_traits>
 #include <sstream>

+#include "stream_config.hpp"
 #include "device_base.hpp"
 #include "host_tensor.hpp"
@@ -88,13 +89,16 @@ struct ReferenceConvFwd : public device::BaseOperator
            auto f_ncw = [&](auto n, auto k, auto wo) {
                float v_acc = 0;

-                for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
                {
-                    for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
+                    for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
                    {
-                        int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
-                                 arg.in_left_pads_[0];
-                        if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2])
+                        auto wi =
+                            ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
+                            ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
+                            ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                        if(wi >= 0 &&
+                           ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
                        {
                            float v_in;
                            float v_wei;
@@ -128,18 +132,26 @@ struct ReferenceConvFwd : public device::BaseOperator
            auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
                float v_acc = 0;

-                for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
                {
-                    for(int y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
+                    for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
                    {
-                        int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
-                                 arg.in_left_pads_[0];
-                        for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
+                        auto hi =
+                            ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                            ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                            ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                        for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
                        {
-                            int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
-                                     arg.in_left_pads_[1];
-                            if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 &&
-                               wi < arg.input_.mDesc.GetLengths()[3])
+                            auto wi =
+                                ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                                ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                            if(hi >= 0 &&
+                               ck::type_convert<std::size_t>(hi) <
+                                   arg.input_.mDesc.GetLengths()[2] &&
+                               wi >= 0 &&
+                               ck::type_convert<std::size_t>(wi) <
+                                   arg.input_.mDesc.GetLengths()[3])
                            {
                                float v_in;
                                float v_wei;
@@ -174,23 +186,37 @@ struct ReferenceConvFwd : public device::BaseOperator
            auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
                float v_acc = 0;

-                for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
                {
-                    for(int z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
+                    for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
                    {
-                        int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
-                                 arg.in_left_pads_[0];
-                        for(int y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
+                        auto di =
+                            ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
+                            ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
+                            ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                        for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
                        {
-                            int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
-                                     arg.in_left_pads_[1];
-                            for(int x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
+                            auto hi =
+                                ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
+                                ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
+                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                            for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
                            {
-                                int wi = wo * arg.conv_strides_[2] +
-                                         x * arg.conv_dilations_[2] - arg.in_left_pads_[2];
-                                if(di >= 0 && di < arg.input_.mDesc.GetLengths()[2] &&
-                                   hi >= 0 && hi < arg.input_.mDesc.GetLengths()[3] &&
-                                   wi >= 0 && wi < arg.input_.mDesc.GetLengths()[4])
+                                auto wi =
+                                    ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[2]) +
+                                    ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[2]) -
+                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
+                                if(di >= 0 &&
+                                   ck::type_convert<std::size_t>(di) <
+                                       arg.input_.mDesc.GetLengths()[2] &&
+                                   hi >= 0 &&
+                                   ck::type_convert<std::size_t>(hi) <
+                                       arg.input_.mDesc.GetLengths()[3] &&
+                                   wi >= 0 &&
+                                   ck::type_convert<std::size_t>(wi) <
+                                       arg.input_.mDesc.GetLengths()[4])
                                {
                                    float v_in;
                                    float v_wei;
@@ -226,7 +252,8 @@ struct ReferenceConvFwd : public device::BaseOperator
            }
        }

-        float Run(const device::BaseArgument* p_arg, int) override
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /*stream_config*/ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
@@ -286,4 +313,3 @@ struct ReferenceConvFwd : public device::BaseOperator
 } // namespace host
 } // namespace tensor_operation
 } // namespace ck
-#endif
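Note: throughout these hunks the loop indices become std::size_t while the computed coordinates go through ck::type_convert<ck::long_index_t> before the bounds checks. A plausible reading, shown here as a hedged sketch rather than a statement about the authors' intent, is that all-unsigned arithmetic would wrap around for padded positions instead of going negative.

// Illustrative only: why the padded coordinate needs a signed type once the loop
// indices are std::size_t. Plain C++, not CK code.
#include <cstdint>
#include <iostream>

int main()
{
    const std::size_t wo = 0, x = 0; // unsigned loop indices, as in the references
    const std::size_t stride = 1, dilation = 1, left_pad = 1;

    // All-unsigned arithmetic wraps: 0 + 0 - 1 becomes a huge positive value,
    // so a plain "wi < Wi" check would not reject the out-of-bounds position.
    const std::size_t wi_unsigned = wo * stride + x * dilation - left_pad;

    // Converting to a signed 64-bit value first (the role played by ck::long_index_t)
    // keeps the coordinate at -1, which the "wi >= 0" check can reject.
    const std::int64_t wi_signed = static_cast<std::int64_t>(wo * stride) +
                                   static_cast<std::int64_t>(x * dilation) -
                                   static_cast<std::int64_t>(left_pad);

    std::cout << wi_unsigned << " vs " << wi_signed << '\n'; // e.g. 18446744073709551615 vs -1
    return 0;
}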
@@ -73,18 +73,25 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
            auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
                float v_acc = 0;

-                for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
+                for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
                {
-                    for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
+                    for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
                    {
-                        int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
-                                 arg.in_left_pads_[0];
-                        for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
+                        auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                                  ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                                  ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                        for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
                        {
-                            int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
-                                     arg.in_left_pads_[1];
-                            if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
-                               wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
+                            auto wi =
+                                ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                                ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                            if(hi >= 0 &&
+                               ck::type_convert<std::size_t>(hi) <
+                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
+                               wi >= 0 &&
+                               ck::type_convert<std::size_t>(wi) <
+                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
                            {
                                float v_in;
                                float v_wei;
@@ -117,7 +124,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
            return 0;
        }

-        float Run(const device::BaseArgument* p_arg, int) override
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
......
@@ -76,18 +76,25 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
            auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
                float v_acc = 0;

-                for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
+                for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
                {
-                    for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
+                    for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
                    {
-                        int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
-                                 arg.in_left_pads_[0];
-                        for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
+                        auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                                  ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                                  ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                        for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
                        {
-                            int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
-                                     arg.in_left_pads_[1];
-                            if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
-                               wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
+                            auto wi =
+                                ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                                ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                            if(hi >= 0 &&
+                               ck::type_convert<std::size_t>(hi) <
+                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
+                               wi >= 0 &&
+                               ck::type_convert<std::size_t>(wi) <
+                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
                            {
                                float v_in;
                                float v_wei;
@@ -123,7 +130,8 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
            return 0;
        }

-        float Run(const device::BaseArgument* p_arg, int) override
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /*stream_config*/ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
......
-#ifndef REFERENCE_GEMM_HPP
-#define REFERENCE_GEMM_HPP
+#pragma once

 #include <iostream>
 #include <sstream>
 #include "device_base.hpp"
@@ -13,6 +11,7 @@ namespace host {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename AccDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
@@ -55,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator
            auto f_mk_kn_mn = [&](auto m, auto n) {
                const int K = arg.a_m_k_.mDesc.GetLengths()[1];

-                float v_acc = 0;
+                AccDataType v_acc = 0;

                for(int k = 0; k < K; ++k)
                {
-                    float v_a;
-                    float v_b;
+                    AccDataType v_a;
+                    AccDataType v_b;

-                    arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
-                    arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
+                    arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k)));
+                    arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n)));

                    v_acc += v_a * v_b;
                }

-                float v_c;
+                AccDataType v_c;

                arg.c_element_op_(v_c, v_acc);
@@ -82,7 +81,8 @@ struct ReferenceGemm : public device::BaseOperator
            return 0;
        }

-        float Run(const device::BaseArgument* p_arg, int) override
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
@@ -129,4 +129,3 @@ struct ReferenceGemm : public device::BaseOperator
 } // namespace host
 } // namespace tensor_operation
 } // namespace ck
-#endif
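Note: ReferenceGemm now takes an explicit AccDataType template parameter and accumulates in that type instead of hard-coded float. A hedged standalone sketch of a GEMM loop with a caller-chosen accumulator type follows (invented names, not the CK interface).

// Illustrative only: a standalone GEMM loop with a caller-chosen accumulator type,
// mirroring the AccDataType parameter added above.
#include <cstddef>
#include <vector>

template <typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
void gemm_reference(std::size_t M, std::size_t N, std::size_t K,
                    const std::vector<ADataType>& a,  // M x K, row-major
                    const std::vector<BDataType>& b,  // K x N, row-major
                    std::vector<CDataType>& c)        // M x N, row-major
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            AccDataType acc = 0; // accumulate in AccDataType, e.g. float for half inputs
            for(std::size_t k = 0; k < K; ++k)
                acc += static_cast<AccDataType>(a[m * K + k]) *
                       static_cast<AccDataType>(b[k * N + n]);
            c[m * N + n] = static_cast<CDataType>(acc);
        }
}

// Possible usage: instantiating with half-precision storage types and AccDataType = float
// keeps the dot-product accumulation in float even when the stored matrices are half precision.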
@@ -82,7 +82,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
            return 0;
        }

-        float Run(const device::BaseArgument* p_arg, int) override
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
......
@@ -85,7 +85,8 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator
            return 0;
        }

-        float Run(const device::BaseArgument* p_arg, int) override
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
......
@@ -91,7 +91,8 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
            return 0;
        }

-        float Run(const device::BaseArgument* p_arg, int) override
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
......
@@ -9,26 +9,11 @@
 #include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
 #include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
 #include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
-#include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
-#include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
-#include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
-#include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
-#include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
-#include "device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp"
-#include "device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp"
-#include "device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
-#include "device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp"
 #include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
 #include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
 #include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
......
@@ -3,13 +3,27 @@
 #include "reduction_operator_mapping.hpp"
 #include "device_reduce_instance_impl_common.hpp"
-#include "device_reduce_blockwise.hpp"
+#include "device_reduce_multiblock.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace device_reduce_instance {

+using reduce_configuration_1_instances_blockwise = std::tuple<
+    // clang-format off
+    // BlockSize | MThreadClusterSize | KThreadClusterSize
+    ReductionConfiguration_1<256, 128, 2>,
+    ReductionConfiguration_1<256, 64, 4>,
+    ReductionConfiguration_1<256, 32, 8>,
+    ReductionConfiguration_1<256, 16, 16>,
+    ReductionConfiguration_1<256, 8, 32>,
+    ReductionConfiguration_1<256, 4, 64>,
+    ReductionConfiguration_1<256, 2, 128>,
+    ReductionConfiguration_1<256, 1, 256>
+    // clang-format on
+    >;
+
 #ifdef QUICK_REDUCE_TEST
 using reduce_configuration_2_instances_blockwise = std::tuple<
     // clang-format off
@@ -58,8 +72,8 @@ template <typename InDataType,
           int Rank,
           int NumReduceDim,
           ReduceTensorOp ReduceOpId,
-          NanPropagation NanOpt,
-          ReduceTensorIndices IndicesOpt>
+          bool PropagateNan,
+          bool UseIndex>
 void add_device_reduce_instance_blockwise(
     std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
 {
@@ -73,92 +87,94 @@ void add_device_reduce_instance_blockwise(
     constexpr bool Indexable =
         (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
          ReduceOpId == ReduceTensorOp::AMAX);
-    constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES);
-    constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
-
-    static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) {
-        using cfg1 =
-            remove_cvref_t<decltype(std::get<i.value>(reduce_configuration_1_instances{}))>;
-
-        static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
-            [&](auto j) {
-                using cfg2 = remove_cvref_t<decltype(
-                    std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
-
-                using ReduceOpInstance = DeviceReduceBlockWise<InDataType,
-                                                               AccDataType,
-                                                               OutDataType,
-                                                               Rank,
-                                                               NumReduceDim,
-                                                               ReduceOperation,
-                                                               InElementwiseOperation,
-                                                               AccElementwiseOperation,
-                                                               PropagateNan,
-                                                               NeedIndices,
-                                                               cfg1::BlockSize_,
-                                                               cfg1::MThreadClusterSize_,
-                                                               cfg1::KThreadClusterSize_,
-                                                               cfg2::MThreadSliceSize_,
-                                                               cfg2::KThreadSliceSize_,
-                                                               cfg2::InSrcVectorDim_,
-                                                               cfg2::InSrcVectorSize_,
-                                                               cfg2::OutDstVectorSize_>;
-
-                device_op_instances.push_back(
-                    std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
-            });
-    });
+    constexpr bool OutputIndex = Indexable && UseIndex;
+
+    static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
+        [&](auto i) {
+            using cfg1 = remove_cvref_t<decltype(
+                std::get<i.value>(reduce_configuration_1_instances_blockwise{}))>;
+
+            static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
+                [&](auto j) {
+                    using cfg2 = remove_cvref_t<decltype(
+                        std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
+
+                    using ReduceOpInstance =
+                        DeviceReduceMultiBlock<InDataType,
+                                               AccDataType,
+                                               OutDataType,
+                                               Rank,
+                                               NumReduceDim,
+                                               ReduceOperation,
+                                               InElementwiseOperation,
+                                               AccElementwiseOperation,
+                                               InMemoryDataOperationEnum::Set,
+                                               PropagateNan,
+                                               OutputIndex,
+                                               false, // HaveIndexInputIfOutputIndex
+                                               cfg1::BlockSize_,
+                                               cfg1::MThreadClusterSize_,
+                                               cfg1::KThreadClusterSize_,
+                                               cfg2::MThreadSliceSize_,
+                                               cfg2::KThreadSliceSize_,
+                                               cfg2::InSrcVectorDim_,
+                                               cfg2::InSrcVectorSize_,
+                                               cfg2::OutDstVectorSize_>;
+
+                    device_op_instances.push_back(
+                        std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
+                });
+        });
 };

 #define ADD_BLOCKWISE_INST_BY_TYPE(                                                      \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                \
+    inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim)            \
     template void add_device_reduce_instance_blockwise<inT,                              \
                                                        compT,                            \
                                                        outT,                             \
                                                        Rank,                             \
                                                        NumReduceDim,                     \
                                                        ReduceOpId,                       \
-                                                       NanOpt,                           \
-                                                       IndicesOpt>(                      \
+                                                       PropagateNan,                     \
+                                                       UseIndex>(                        \
         std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)

 #define ADD_BLOCKWISE_INST_BY_ID(                                                        \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                \
     ADD_BLOCKWISE_INST_BY_TYPE(inT,                                                      \
                                compT,                                                    \
                                outT,                                                     \
                                static_cast<ReduceTensorOp>(ReduceOpId),                  \
-                               static_cast<NanPropagation>(NanOpt),                      \
-                               static_cast<ReduceTensorIndices>(IndicesOpt),             \
+                               static_cast<bool>(NanOpt),                                \
+                               static_cast<bool>(IndicesOpt),                            \
                                Rank,                                                     \
                                NumReduceDim)

 #define ADD_BLOCKWISE_INST_REF_BY_TYPE(                                                  \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                \
+    inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim)            \
     extern template void add_device_reduce_instance_blockwise<inT,                       \
                                                               compT,                     \
                                                               outT,                      \
                                                               Rank,                      \
                                                               NumReduceDim,              \
                                                               ReduceOpId,                \
-                                                              NanOpt,                    \
-                                                              IndicesOpt>(               \
+                                                              PropagateNan,              \
+                                                              UseIndex>(                 \
         std::vector<DeviceReducePtr<                                                     \
             typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
             typename reduce_unary_operator<compT, ReduceOpId, true, true>::              \
                 AccElementwiseOperation>> &                                              \
         device_op_instances)

 #define ADD_BLOCKWISE_INST_REF_BY_ID(                                                    \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                \
     ADD_BLOCKWISE_INST_REF_BY_TYPE(inT,                                                  \
                                    compT,                                                \
                                    outT,                                                 \
                                    static_cast<ReduceTensorOp>(ReduceOpId),              \
-                                   static_cast<NanPropagation>(NanOpt),                  \
-                                   static_cast<ReduceTensorIndices>(IndicesOpt),         \
+                                   static_cast<bool>(NanOpt),                            \
+                                   static_cast<bool>(IndicesOpt),                        \
                                    Rank,                                                 \
                                    NumReduceDim)

 } // namespace device_reduce_instance
......
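Note: add_device_reduce_instance_blockwise above walks reduce_configuration_1_instances_blockwise and reduce_configuration_2_instances_blockwise with nested static_for and registers one DeviceReduceMultiBlock instance per (cfg1, cfg2) pair. Below is a hedged standalone sketch of that tuple cross-product pattern, rebuilt with C++17 fold expressions instead of ck::static_for; every type in it is a placeholder, not a CK type.

// Illustrative only: the tuple cross-product instantiation pattern, with placeholder types.
#include <cstddef>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

struct BaseOp { virtual ~BaseOp() = default; };

template <int BlockSize, int MPerThread> struct Cfg1 {};
template <int SliceSize> struct Cfg2 {};

template <typename C1, typename C2> struct ReduceInstance : BaseOp {};

using Cfg1List = std::tuple<Cfg1<256, 128>, Cfg1<256, 64>>;
using Cfg2List = std::tuple<Cfg2<8>, Cfg2<4>, Cfg2<1>>;

template <typename C1, std::size_t... J>
void add_for_cfg1(std::vector<std::unique_ptr<BaseOp>>& ops, std::index_sequence<J...>)
{
    // One instance per cfg2 for a fixed cfg1.
    (ops.push_back(std::make_unique<
         ReduceInstance<C1, std::tuple_element_t<J, Cfg2List>>>()), ...);
}

template <std::size_t... I>
void add_all(std::vector<std::unique_ptr<BaseOp>>& ops, std::index_sequence<I...>)
{
    // Outer loop over cfg1, inner loop over cfg2: the same cross product the
    // nested static_for builds in add_device_reduce_instance_blockwise.
    (add_for_cfg1<std::tuple_element_t<I, Cfg1List>>(
         ops, std::make_index_sequence<std::tuple_size_v<Cfg2List>>{}), ...);
}

int main()
{
    std::vector<std::unique_ptr<BaseOp>> ops;
    add_all(ops, std::make_index_sequence<std::tuple_size_v<Cfg1List>>{});
    return ops.size() == 6 ? 0 : 1; // 2 cfg1 x 3 cfg2 instances
}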
 #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
 #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP

-#include "reduction_enums.hpp"
-#include "reduction_operator_mapping.hpp"
+#include "data_type.hpp"
 #include "device_reduce_instance_blockwise.hpp"

 namespace ck {
......
 #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
 #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP

-#include "reduction_enums.hpp"
-#include "reduction_operator_mapping.hpp"
+#include "data_type.hpp"
 #include "device_reduce_instance_blockwise.hpp"

 namespace ck {
......
 #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
 #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP

-#include "reduction_enums.hpp"
-#include "reduction_operator_mapping.hpp"
+#include "data_type.hpp"
 #include "device_reduce_instance_blockwise.hpp"

 namespace ck {
......
 #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
 #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP

-#include "reduction_enums.hpp"
-#include "reduction_operator_mapping.hpp"
 #include "device_reduce_instance_blockwise.hpp"

 namespace ck {
......
 #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
 #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP

-#include "reduction_enums.hpp"
-#include "reduction_operator_mapping.hpp"
 #include "device_reduce_instance_blockwise.hpp"

 namespace ck {
......
 #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
 #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP

-#include "reduction_enums.hpp"
-#include "reduction_operator_mapping.hpp"
 #include "device_reduce_instance_blockwise.hpp"

 namespace ck {
......
 #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP
 #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP

-#include "reduction_enums.hpp"
-#include "reduction_operator_mapping.hpp"
 #include "device_reduce_instance_blockwise.hpp"

 namespace ck {
......