Commit a3b4c5cb authored by wangshaojie6's avatar wangshaojie6
Browse files

merge develop branch and add gridwise pipeline v3

parents 48918ab9 1677cf70
......@@ -84,7 +84,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// FIXME: support arbitrary elementwise operation for A/B/C
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct ReferenceCGemm : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k_real,
const Tensor<ADataType>& a_m_k_imag,
const Tensor<BDataType>& b_k_n_real,
const Tensor<BDataType>& b_k_n_imag,
Tensor<CDataType>& c_m_n_real,
Tensor<CDataType>& c_m_n_imag,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_real_{a_m_k_real},
a_m_k_imag_{a_m_k_imag},
b_k_n_real_{b_k_n_real},
b_k_n_imag_{b_k_n_imag},
c_m_n_real_{c_m_n_real},
c_m_n_imag_{c_m_n_imag},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_real_;
const Tensor<ADataType>& a_m_k_imag_;
const Tensor<BDataType>& b_k_n_real_;
const Tensor<BDataType>& b_k_n_imag_;
Tensor<CDataType>& c_m_n_real_;
Tensor<CDataType>& c_m_n_imag_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceCGemm::Argument;
float Run(const Argument& arg)
{
const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1];
if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1])
{
throw std::runtime_error("wrong! Incompatible real and imag sizes in CGEMM");
}
auto f_mk_kn_mn_real = [&](auto m, auto n) {
float v_c_real = 0;
for(std::size_t k = 0; k < K; ++k)
{
float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));
v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag;
}
arg.c_m_n_real_(m, n) = v_c_real;
};
auto f_mk_kn_mn_imag = [&](auto m, auto n) {
float v_c_imag = 0;
for(std::size_t k = 0; k < K; ++k)
{
float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));
v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real;
}
arg.c_m_n_imag_(m, n) = v_c_imag;
};
make_ParallelTensorFunctor(f_mk_kn_mn_real,
arg.c_m_n_real_.mDesc.GetLengths()[0],
arg.c_m_n_real_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_mk_kn_mn_imag,
arg.c_m_n_imag_.mDesc.GetLengths()[0],
arg.c_m_n_imag_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k_real,
const Tensor<ADataType>& a_m_k_imag,
const Tensor<BDataType>& b_k_n_real,
const Tensor<BDataType>& b_k_n_imag,
Tensor<CDataType>& c_m_n_real,
Tensor<CDataType>& c_m_n_imag,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_m_k_real,
a_m_k_imag,
b_k_n_real,
b_k_n_imag,
c_m_n_real,
c_m_n_imag,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceCGemm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#ifndef REFERENCE_CONV_WRW_HPP
#define REFERENCE_CONV_WRW_HPP
#pragma once
#include <iostream>
#include <sstream>
......@@ -16,7 +15,9 @@ template <typename InDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2,
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator
{
// Argument
......@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: in_n_c_hi_wi_{in_n_c_hi_wi},
wei_k_c_y_x_{wei_k_c_y_x},
out_n_k_ho_wo_{out_n_k_ho_wo},
: input_{in_n_c_hi_wi},
weight_{wei_k_c_y_x},
output_{out_n_k_ho_wo},
conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
......@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
{
}
const Tensor<InDataType>& in_n_c_hi_wi_;
Tensor<WeiDataType>& wei_k_c_y_x_;
const Tensor<OutDataType>& out_n_k_ho_wo_;
const Tensor<InDataType>& input_;
Tensor<WeiDataType>& weight_;
const Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
......@@ -66,55 +67,184 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float Run(const Argument& arg)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0;
for(int n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n)
{
for(int ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho)
if constexpr(NumDimSpatial == 1)
{
constexpr auto I0 = Number<0>{};
auto f_kcx = [&](auto k, auto c, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
{
int hi = ho * arg.conv_strides_[I0] + y * arg.conv_dilations_[I0] -
arg.in_left_pads_[I0];
for(int wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo)
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo)
{
int wi = wo * arg.conv_strides_[I1] + x * arg.conv_dilations_[I1] -
arg.in_left_pads_[I1];
if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
{
float v_out;
float v_in;
arg.out_element_op_(
v_out,
ck::type_convert<float>(arg.out_n_k_ho_wo_(n, k, ho, wo)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
arg.out_element_op_(v_out,
ck::type_convert<float>(arg.output_(n, k, wo)));
arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, wi)));
v_acc += v_out * v_in;
}
}
}
}
float v_wei;
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.wei_element_op_(v_wei, v_acc);
arg.wei_k_c_y_x_(k, c, y, x) = ck::type_convert<OutDataType>(v_wei);
};
arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kcyx,
arg.wei_k_c_y_x_.mDesc.GetLengths()[0],
arg.wei_k_c_y_x_.mDesc.GetLengths()[1],
arg.wei_k_c_y_x_.mDesc.GetLengths()[2],
arg.wei_k_c_y_x_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_kcx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2])(
std::thread::hardware_concurrency());
return 0;
return 0;
}
else if constexpr(NumDimSpatial == 2)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
{
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[3])
{
float v_out;
float v_in;
arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kcyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
{
for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_)
{
auto di =
ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(y *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4];
++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo *
arg.conv_strides_[I2]) +
ck::type_convert<ck::long_index_t>(
x * arg.conv_dilations_[I2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4])
{
float v_out;
float v_in;
arg.out_element_op_(v_out,
ck::type_convert<float>(
arg.output_(n, k, do_, ho, wo)));
arg.in_element_op_(
v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kczyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3],
arg.weight_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......@@ -174,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -78,15 +78,18 @@ struct ReferenceConvBwdData : public device::BaseOperator
AccDataType v_acc = 0;
for(int x = 0; x < X; ++x)
for(std::size_t x = 0; x < X; ++x)
{
int w_tmp = wi + arg.in_left_pads_[0] - x * arg.conv_dilations_[0];
auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]);
if(w_tmp % arg.conv_strides_[0] == 0)
{
int wo = w_tmp / arg.conv_strides_[0];
if(wo >= 0 && wo < Wo)
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
for(int k = 0; k < K; ++k)
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_out = 0;
AccDataType v_wei = 0;
......@@ -128,24 +131,32 @@ struct ReferenceConvBwdData : public device::BaseOperator
AccDataType v_acc = 0;
for(int y = 0; y < Y; ++y)
for(std::size_t y = 0; y < Y; ++y)
{
int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];
auto h_tmp = ck::type_convert<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]);
if(h_tmp % arg.conv_strides_[0] == 0)
{
int ho = h_tmp / arg.conv_strides_[0];
if(ho >= 0 && ho < Ho)
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(int x = 0; x < X; ++x)
for(std::size_t x = 0; x < X; ++x)
{
int w_tmp =
wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];
auto w_tmp =
ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[1]);
if(w_tmp % arg.conv_strides_[1] == 0)
{
int wo = w_tmp / arg.conv_strides_[1];
if(wo >= 0 && wo < Wo)
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(
arg.conv_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
for(int k = 0; k < K; ++k)
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_out = 0;
AccDataType v_wei = 0;
......@@ -194,33 +205,49 @@ struct ReferenceConvBwdData : public device::BaseOperator
AccDataType v_acc = 0;
for(int z = 0; z < Z; ++z)
for(std::size_t z = 0; z < Z; ++z)
{
int d_tmp = di + arg.in_left_pads_[0] - z * arg.conv_dilations_[0];
auto d_tmp = ck::type_convert<ck::long_index_t>(di) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]);
if(d_tmp % arg.conv_strides_[0] == 0)
{
int do_ = d_tmp / arg.conv_strides_[0];
if(do_ >= 0 && do_ < Do)
auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{
for(int y = 0; y < Y; ++y)
for(std::size_t y = 0; y < Y; ++y)
{
int h_tmp =
hi + arg.in_left_pads_[1] - y * arg.conv_dilations_[1];
auto h_tmp =
ck::type_convert<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(y *
arg.conv_dilations_[1]);
if(h_tmp % arg.conv_strides_[1] == 0)
{
int ho = h_tmp / arg.conv_strides_[1];
if(ho >= 0 && ho < Ho)
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
ck::type_convert<ck::long_index_t>(
arg.conv_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(int x = 0; x < X; ++x)
for(std::size_t x = 0; x < X; ++x)
{
int w_tmp = wi + arg.in_left_pads_[2] -
x * arg.conv_dilations_[2];
auto w_tmp =
ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(
arg.in_left_pads_[2]) -
ck::type_convert<ck::long_index_t>(
x * arg.conv_dilations_[2]);
if(w_tmp % arg.conv_strides_[2] == 0)
{
int wo = w_tmp / arg.conv_strides_[2];
if(wo >= 0 && wo < Wo)
auto wo =
ck::type_convert<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(
arg.conv_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{
for(int k = 0; k < K; ++k)
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_out = 0;
AccDataType v_wei = 0;
......@@ -264,7 +291,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
#ifndef REFERENCE_CONV_FWD_HPP
#define REFERENCE_CONV_FWD_HPP
#pragma once
#include <iostream>
#include <type_traits>
#include <sstream>
#include "stream_config.hpp"
#include "device_base.hpp"
#include "host_tensor.hpp"
......@@ -88,13 +89,16 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_ncw = [&](auto n, auto k, auto wo) {
float v_acc = 0;
for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
{
for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
{
int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2])
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
{
float v_in;
float v_wei;
......@@ -128,18 +132,26 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0;
for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
{
for(int y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
{
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
{
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
arg.in_left_pads_[1];
if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 &&
wi < arg.input_.mDesc.GetLengths()[3])
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[3])
{
float v_in;
float v_wei;
......@@ -174,23 +186,37 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0;
for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
{
for(int z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
{
int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
for(int y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
auto di =
ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
{
int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
arg.in_left_pads_[1];
for(int x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
{
int wi = wo * arg.conv_strides_[2] +
x * arg.conv_dilations_[2] - arg.in_left_pads_[2];
if(di >= 0 && di < arg.input_.mDesc.GetLengths()[2] &&
hi >= 0 && hi < arg.input_.mDesc.GetLengths()[3] &&
wi >= 0 && wi < arg.input_.mDesc.GetLengths()[4])
auto wi =
ck::type_convert<ck::long_index_t>(wo *
arg.conv_strides_[2]) +
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4])
{
float v_in;
float v_wei;
......@@ -226,7 +252,8 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......@@ -286,4 +313,3 @@ struct ReferenceConvFwd : public device::BaseOperator
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -73,18 +73,25 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0;
for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
{
for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
{
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
{
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
arg.in_left_pads_[1];
if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
{
float v_in;
float v_wei;
......@@ -117,7 +124,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
......@@ -76,18 +76,25 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0;
for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
{
for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
{
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
{
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
arg.in_left_pads_[1];
if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
{
float v_in;
float v_wei;
......@@ -123,7 +130,8 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
#ifndef REFERENCE_GEMM_HPP
#define REFERENCE_GEMM_HPP
#pragma once
#include <iostream>
#include <sstream>
#include "device_base.hpp"
......@@ -13,6 +11,7 @@ namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
......@@ -55,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
float v_acc = 0;
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
float v_a;
float v_b;
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n)));
v_acc += v_a * v_b;
}
float v_c;
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
......@@ -82,7 +81,8 @@ struct ReferenceGemm : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......@@ -129,4 +129,3 @@ struct ReferenceGemm : public device::BaseOperator
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -82,7 +82,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
......@@ -85,7 +85,8 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
......@@ -91,7 +91,8 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
......
......@@ -9,26 +9,11 @@
#include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
#include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
#include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
#include "device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp"
#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp"
#include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
#include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
#include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
......
......@@ -3,13 +3,27 @@
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_blockwise.hpp"
#include "device_reduce_multiblock.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
using reduce_configuration_1_instances_blockwise = std::tuple<
// clang-format off
// BlockSize | MThreadClusterSize | KThreadClusterSize
ReductionConfiguration_1<256, 128, 2>,
ReductionConfiguration_1<256, 64, 4>,
ReductionConfiguration_1<256, 32, 8>,
ReductionConfiguration_1<256, 16, 16>,
ReductionConfiguration_1<256, 8, 32>,
ReductionConfiguration_1<256, 4, 64>,
ReductionConfiguration_1<256, 2, 128>,
ReductionConfiguration_1<256, 1, 256>
// clang-format on
>;
#ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_blockwise = std::tuple<
// clang-format off
......@@ -58,8 +72,8 @@ template <typename InDataType,
int Rank,
int NumReduceDim,
ReduceTensorOp ReduceOpId,
NanPropagation NanOpt,
ReduceTensorIndices IndicesOpt>
bool PropagateNan,
bool UseIndex>
void add_device_reduce_instance_blockwise(
std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
{
......@@ -73,92 +87,94 @@ void add_device_reduce_instance_blockwise(
constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
ReduceOpId == ReduceTensorOp::AMAX);
constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES);
constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) {
using cfg1 =
remove_cvref_t<decltype(std::get<i.value>(reduce_configuration_1_instances{}))>;
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
[&](auto j) {
using cfg2 = remove_cvref_t<decltype(
std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
using ReduceOpInstance = DeviceReduceBlockWise<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
NeedIndices,
cfg1::BlockSize_,
cfg1::MThreadClusterSize_,
cfg1::KThreadClusterSize_,
cfg2::MThreadSliceSize_,
cfg2::KThreadSliceSize_,
cfg2::InSrcVectorDim_,
cfg2::InSrcVectorSize_,
cfg2::OutDstVectorSize_>;
device_op_instances.push_back(
std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
});
});
constexpr bool OutputIndex = Indexable && UseIndex;
static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
[&](auto i) {
using cfg1 = remove_cvref_t<decltype(
std::get<i.value>(reduce_configuration_1_instances_blockwise{}))>;
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
[&](auto j) {
using cfg2 = remove_cvref_t<decltype(
std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
using ReduceOpInstance =
DeviceReduceMultiBlock<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
PropagateNan,
OutputIndex,
false, // HaveIndexInputIfOutputIndex
cfg1::BlockSize_,
cfg1::MThreadClusterSize_,
cfg1::KThreadClusterSize_,
cfg2::MThreadSliceSize_,
cfg2::KThreadSliceSize_,
cfg2::InSrcVectorDim_,
cfg2::InSrcVectorSize_,
cfg2::OutDstVectorSize_>;
device_op_instances.push_back(
std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
});
});
};
#define ADD_BLOCKWISE_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
#define ADD_BLOCKWISE_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<NanPropagation>(NanOpt), \
static_cast<ReduceTensorIndices>(IndicesOpt), \
Rank, \
#define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<bool>(NanOpt), \
static_cast<bool>(IndicesOpt), \
Rank, \
NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
PropagateNan, \
UseIndex>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<NanPropagation>(NanOpt), \
static_cast<ReduceTensorIndices>(IndicesOpt), \
Rank, \
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<bool>(NanOpt), \
static_cast<bool>(IndicesOpt), \
Rank, \
NumReduceDim)
} // namespace device_reduce_instance
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "data_type.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "data_type.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "data_type.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment