Unverified Commit 500fa995 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Clean up conv example, Instances, profiler and test (#324)

* convnd_fwd fp16 example

* update example

* update example

* update instance

* updating refernce conv

* update reference conv

* update conv fwd profiler

* update conv 1d and 3d instance

* update include path

* clean

* update profiler for conv bwd data and weight

* update conv bwd weight

* clean

* update conv example

* update profiler for conv bwd weight

* update ckprofiler for conv bwd data

* fix reference conv bwd data bug; update conv bwd data test

* update examples

* fix initialization issue

* update test for conv fwd

* clean

* clean

* remove test case too sensitive to error threshhold

* fix test

* clean

* fix build

* adding conv multiple d

* adding conv multiple D

* add matrix padder

* add gemm padding to convnd

* adding group conv

* update gemm multi-d

* refactor

* refactor

* refactor

* clean

* clean

* refactor

* refactor

* reorg

* add ds

* add bias

* clean

* add G

* adding group

* adding group

* adding group

* update Tensor

* clean

* update example

* update DeviceGemmMultipleD_Xdl_CShuffle

* update conv bwd-data and bwd-weight

* upate contraction example

* update gemm and batch gemm with e permute

* fix example build

* instance for grouped conv1d

* update example

* adding group conv instance

* update gemm bilinear instance

* update gemm+add+add+fastgelu instance

* update profiler

* update profiler

* update test

* update test and client example

* clean

* add grouped conv into profiler

* update profiler

* clean

* add test grouped conv, update all conv test to gtest

* update test
parent 85978e02
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -8,22 +8,24 @@ ...@@ -8,22 +8,24 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] // input descriptor in [G, N, C, Do, Ho, Wo] order
template <typename InDataType, // weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename AccDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2, typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdData : public device::BaseOperator struct ReferenceConvBwdData : public device::BaseOperator
{ {
// Argument // Argument
...@@ -73,36 +75,45 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -73,36 +75,45 @@ struct ReferenceConvBwdData : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if constexpr(NumDimSpatial == 1) if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{ {
auto f_ncw = [&](auto n, auto c, auto wi) { throw std::runtime_error("wrong! inconsistent dimension");
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; }
std::size_t X = arg.weight_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[2]; if constexpr(NDimSpatial == 1)
{
auto f_ncw = [&](auto g, auto n, auto c, auto wi) {
std::size_t K = arg.weight_.GetLengths()[1];
std::size_t X = arg.weight_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[3];
AccDataType v_acc = 0; float v_acc = 0;
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = ck::type_convert<ck::long_index_t>(wi) + auto w_tmp = static_cast<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]);
if(w_tmp % arg.conv_strides_[0] == 0) if(w_tmp % arg.conv_strides_[0] == 0)
{ {
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) / auto wo = static_cast<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo) if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; float v_out = 0;
AccDataType v_wei = 0; float v_wei = 0;
arg.out_element_op_( arg.out_element_op_(
v_out, v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x))); v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -110,66 +121,72 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -110,66 +121,72 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
arg.in_element_op_(v_acc, v_acc); float v_in;
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
arg.input_.mDesc.GetLengths()[0], arg.input_.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2])( arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { auto f_nchw = [&](auto g, auto n, auto c, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Y = arg.weight_.mDesc.GetLengths()[2]; std::size_t Y = arg.weight_.GetLengths()[3];
std::size_t X = arg.weight_.mDesc.GetLengths()[3]; std::size_t X = arg.weight_.GetLengths()[4];
std::size_t Ho = arg.output_.mDesc.GetLengths()[2]; std::size_t Ho = arg.output_.GetLengths()[3];
std::size_t Wo = arg.output_.mDesc.GetLengths()[3]; std::size_t Wo = arg.output_.GetLengths()[4];
AccDataType v_acc = 0; float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
{ {
auto h_tmp = ck::type_convert<ck::long_index_t>(hi) + auto h_tmp = static_cast<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]);
if(h_tmp % arg.conv_strides_[0] == 0) if(h_tmp % arg.conv_strides_[0] == 0)
{ {
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) / auto ho = static_cast<ck::long_index_t>(h_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho) if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = auto w_tmp =
ck::type_convert<ck::long_index_t>(wi) + static_cast<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(x * static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]);
arg.conv_dilations_[1]);
if(w_tmp % arg.conv_strides_[1] == 0) if(w_tmp % arg.conv_strides_[1] == 0)
{ {
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) / auto wo =
ck::type_convert<ck::long_index_t>( static_cast<ck::long_index_t>(w_tmp) /
arg.conv_strides_[1]); static_cast<ck::long_index_t>(arg.conv_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo) if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; float v_out = 0;
AccDataType v_wei = 0; float v_wei = 0;
arg.out_element_op_(
v_out,
ck::type_convert<float>(
arg.output_(g, n, k, ho, wo)));
arg.out_element_op_(v_out, arg.wei_element_op_(
ck::type_convert<AccDataType>( v_wei,
arg.output_(n, k, ho, wo))); ck::type_convert<float>(
arg.wei_element_op_(v_wei, arg.weight_(g, k, c, y, x)));
ck::type_convert<AccDataType>(
arg.weight_(k, c, y, x)));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -180,90 +197,91 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -180,90 +197,91 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
AccDataType v_in; float v_in;
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
arg.input_.mDesc.GetLengths()[0], arg.input_.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2], arg.input_.GetLengths()[2],
arg.input_.mDesc.GetLengths()[3])( arg.input_.GetLengths()[3],
arg.input_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) { auto f_ncdhw = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Z = arg.weight_.mDesc.GetLengths()[2]; std::size_t Z = arg.weight_.GetLengths()[3];
std::size_t Y = arg.weight_.mDesc.GetLengths()[3]; std::size_t Y = arg.weight_.GetLengths()[4];
std::size_t X = arg.weight_.mDesc.GetLengths()[4]; std::size_t X = arg.weight_.GetLengths()[5];
std::size_t Do = arg.output_.mDesc.GetLengths()[2]; std::size_t Do = arg.output_.GetLengths()[3];
std::size_t Ho = arg.output_.mDesc.GetLengths()[3]; std::size_t Ho = arg.output_.GetLengths()[4];
std::size_t Wo = arg.output_.mDesc.GetLengths()[4]; std::size_t Wo = arg.output_.GetLengths()[5];
AccDataType v_acc = 0; float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z) for(std::size_t z = 0; z < Z; ++z)
{ {
auto d_tmp = ck::type_convert<ck::long_index_t>(di) + auto d_tmp = static_cast<ck::long_index_t>(di) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]);
if(d_tmp % arg.conv_strides_[0] == 0) if(d_tmp % arg.conv_strides_[0] == 0)
{ {
auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) / auto do_ = static_cast<ck::long_index_t>(d_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do) if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{ {
for(std::size_t y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
{ {
auto h_tmp = auto h_tmp =
ck::type_convert<ck::long_index_t>(hi) + static_cast<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(y * static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]);
arg.conv_dilations_[1]);
if(h_tmp % arg.conv_strides_[1] == 0) if(h_tmp % arg.conv_strides_[1] == 0)
{ {
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) / auto ho =
ck::type_convert<ck::long_index_t>( static_cast<ck::long_index_t>(h_tmp) /
arg.conv_strides_[1]); static_cast<ck::long_index_t>(arg.conv_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho) if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = auto w_tmp = static_cast<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(wi) + static_cast<ck::long_index_t>(
ck::type_convert<ck::long_index_t>( arg.in_left_pads_[2]) -
arg.in_left_pads_[2]) - static_cast<ck::long_index_t>(
ck::type_convert<ck::long_index_t>( x * arg.conv_dilations_[2]);
x * arg.conv_dilations_[2]);
if(w_tmp % arg.conv_strides_[2] == 0) if(w_tmp % arg.conv_strides_[2] == 0)
{ {
auto wo = auto wo = static_cast<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(w_tmp) / static_cast<ck::long_index_t>(
ck::type_convert<ck::long_index_t>( arg.conv_strides_[2]);
arg.conv_strides_[2]);
if(wo >= 0 && if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo) ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; float v_out = 0;
AccDataType v_wei = 0; float v_wei = 0;
arg.out_element_op_( arg.out_element_op_(
v_out, v_out,
ck::type_convert<AccDataType>( ck::type_convert<float>(arg.output_(
arg.output_( g, n, k, do_, ho, wo)));
n, k, do_, ho, wo)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei,
ck::type_convert<AccDataType>( ck::type_convert<float>(
arg.weight_(k, c, z, y, x))); arg.weight_(g, k, c, z, y, x)));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -277,17 +295,20 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -277,17 +295,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
AccDataType v_in; float v_in;
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
arg.input_.mDesc.GetLengths()[0], arg.input_.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2], arg.input_.GetLengths()[2],
arg.input_.mDesc.GetLengths()[3], arg.input_.GetLengths()[3],
arg.input_.mDesc.GetLengths()[4])( arg.input_.GetLengths()[4],
arg.input_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
...@@ -7,21 +7,25 @@ ...@@ -7,21 +7,25 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] // input descriptor in [G, N, C, Do, Ho, Wo] order
template <typename InDataType, // weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2, typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator struct ReferenceConvBwdWeight : public device::BaseOperator
{ {
// Argument // Argument
...@@ -71,156 +75,162 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -71,156 +75,162 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if constexpr(NumDimSpatial == 1) if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{ {
constexpr auto I0 = Number<0>{}; throw std::runtime_error("wrong! inconsistent dimension");
auto f_kcx = [&](auto k, auto c, auto x) { }
if constexpr(NDimSpatial == 1)
{
auto f_kcx = [&](auto g, auto k, auto c, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo) for(std::size_t wo = 0; wo < arg.output_.GetLengths()[3]; ++wo)
{ {
auto wi = auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_(v_out, arg.out_element_op_(
ck::type_convert<float>(arg.output_(n, k, wo))); v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, wi))); arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
} }
} }
float v_wei; float v_wei;
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei); arg.weight_(g, k, c, x) = ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kcx, make_ParallelTensorFunctor(f_kcx,
arg.weight_.mDesc.GetLengths()[0], arg.weight_.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1], arg.weight_.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2])( arg.weight_.GetLengths()[2],
arg.weight_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
constexpr auto I0 = Number<0>{}; auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho) for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
{ {
auto hi = auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) + static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo) for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
arg.conv_dilations_[I1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
arg.input_.mDesc.GetLengths()[3])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_( arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo))); v_out,
ck::type_convert<float>(arg.output_(g, n, k, ho, wo)));
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
} }
} }
} }
float v_wei; float v_wei;
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei); arg.weight_(g, k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kcyx, make_ParallelTensorFunctor(f_kcyx,
arg.weight_.mDesc.GetLengths()[0], arg.weight_.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1], arg.weight_.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2], arg.weight_.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3])( arg.weight_.GetLengths()[3],
arg.weight_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
constexpr auto I0 = Number<0>{}; auto f_kczyx = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_) for(std::size_t do_ = 0; do_ < arg.output_.GetLengths()[3]; ++do_)
{ {
auto di = auto di = static_cast<ck::long_index_t>(do_ * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) + static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]); for(std::size_t ho = 0; ho < arg.output_.GetLengths()[4]; ++ho)
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
{ {
auto hi = auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(y * static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
arg.conv_dilations_[I1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]); for(std::size_t wo = 0; wo < arg.output_.GetLengths()[5]; ++wo)
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4];
++wo)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
arg.conv_strides_[I2]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>( static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
x * arg.conv_dilations_[I2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] && arg.input_.GetLengths()[3] &&
hi >= 0 && hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] && arg.input_.GetLengths()[4] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4]) arg.input_.GetLengths()[5])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_(v_out, arg.out_element_op_(v_out,
ck::type_convert<float>( ck::type_convert<float>(
arg.output_(n, k, do_, ho, wo))); arg.output_(g, n, k, do_, ho, wo)));
arg.in_element_op_(
v_in, arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi))); ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
...@@ -228,19 +238,21 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -228,19 +238,21 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
} }
} }
} }
float v_wei; float v_wei;
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei); arg.weight_(g, k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kczyx, make_ParallelTensorFunctor(f_kczyx,
arg.weight_.mDesc.GetLengths()[0], arg.weight_.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1], arg.weight_.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2], arg.weight_.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3], arg.weight_.GetLengths()[3],
arg.weight_.mDesc.GetLengths()[4])( arg.weight_.GetLengths()[4],
arg.weight_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -17,9 +17,10 @@ namespace host { ...@@ -17,9 +17,10 @@ namespace host {
// //
// @brief Reference implementation for forward convolution. // @brief Reference implementation for forward convolution.
// //
// @paragraph Supports both NCHW as well as NHWC formats (and their respective // @paragraph
// counterparts for weight and output) as long as tensor descriptor // Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
// lengths is in NCHW. // Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
// as long as dimensions in tensor descriptor is in GNCHW order
// //
// @tparam InDataType Input tensor data type. // @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type. // @tparam WeiDataType Weights tensor data type.
...@@ -28,16 +29,20 @@ namespace host { ...@@ -28,16 +29,20 @@ namespace host {
// operation. // operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise // @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation. // operation.
// @tparam NumDimSpatial Number of spatial dimensions. // @tparam NDimSpatial Number of spatial dimensions.
// //
template <typename InDataType, // input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2, typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator struct ReferenceConvFwd : public device::BaseOperator
{ {
// Argument // Argument
...@@ -86,29 +91,37 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -86,29 +91,37 @@ struct ReferenceConvFwd : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if constexpr(NumDimSpatial == 1) if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{ {
auto f_ncw = [&](auto n, auto k, auto wo) { throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
{
auto func = [&](auto g, auto n, auto k, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
{ {
auto wi = auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
arg.in_element_op_(v_in, arg.in_element_op_(
ck::type_convert<float>(arg.input_(n, c, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
arg.wei_element_op_(v_wei,
ck::type_convert<float>(arg.weight_(k, c, x))); arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
...@@ -118,50 +131,53 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -118,50 +131,53 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(func,
arg.output_.mDesc.GetLengths()[0], arg.output_.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], arg.output_.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2])( arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y) for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y)
{ {
auto hi = auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
arg.input_.mDesc.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x))); v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
} }
...@@ -171,64 +187,65 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -171,64 +187,65 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(func,
arg.output_.mDesc.GetLengths()[0], arg.output_.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], arg.output_.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2], arg.output_.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3])( arg.output_.GetLengths()[3],
arg.output_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z) for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z)
{ {
auto di = auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y)
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
{ {
auto hi = auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
arg.conv_strides_[2]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(x * static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] && arg.input_.GetLengths()[3] &&
hi >= 0 && hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] && arg.input_.GetLengths()[4] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4]) arg.input_.GetLengths()[5])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
arg.in_element_op_( arg.in_element_op_(v_in,
v_in, ck::type_convert<float>(
ck::type_convert<float>(arg.input_(n, c, di, hi, wi))); arg.input_(g, n, c, di, hi, wi)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei,
ck::type_convert<float>(arg.weight_(k, c, z, y, x))); ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
} }
...@@ -239,15 +256,17 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -239,15 +256,17 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(func,
arg.output_.mDesc.GetLengths()[0], arg.output_.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], arg.output_.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2], arg.output_.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3], arg.output_.GetLengths()[3],
arg.output_.mDesc.GetLengths()[4])( arg.output_.GetLengths()[4],
arg.output_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -267,7 +286,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -267,7 +286,10 @@ struct ReferenceConvFwd : public device::BaseOperator
return true; return true;
} }
bool IsSupportedArgument(const device::BaseArgument*) override { return true; } bool IsSupportedArgument(const device::BaseArgument*) override
{
return NDimSpatial >= 1 && NDimSpatial <= 3;
}
static auto MakeArgument(const Tensor<InDataType>& input, static auto MakeArgument(const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
#include <algorithm> #include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
#include <algorithm> #include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -10,22 +10,67 @@ namespace tensor_operation { ...@@ -10,22 +10,67 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// aliasing, for commonly used type // aliasing, for commonly used data type
using F64 = double; using F64 = double;
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using EMPTY_TUPLE = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
using F16_TUPLE = ck::Tuple<F16>; using F16_Tuple = ck::Tuple<F16>;
using F16_F16_TUPLE = ck::Tuple<F16, F16>; using F16_F16_Tuple = ck::Tuple<F16, F16>;
using F32_TUPLE = ck::Tuple<F32>; using F32_Tuple = ck::Tuple<F32>;
// GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Tuple = ck::Tuple<Row>;
using Row_Row_Tuple = ck::Tuple<Row, Row>;
// Conv layout
//
using NWC = ck::tensor_layout::convolution::NWC;
using NHWC = ck::tensor_layout::convolution::NHWC;
using NDHWC = ck::tensor_layout::convolution::NDHWC;
using KXC = ck::tensor_layout::convolution::KXC;
using KYXC = ck::tensor_layout::convolution::KYXC;
using KZYXC = ck::tensor_layout::convolution::KZYXC;
using NWK = ck::tensor_layout::convolution::NWK;
using NHWK = ck::tensor_layout::convolution::NHWK;
using NDHWK = ck::tensor_layout::convolution::NDHWK;
//
using GNWC = ck::tensor_layout::convolution::GNWC;
using GNHWC = ck::tensor_layout::convolution::GNHWC;
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
using GKXC = ck::tensor_layout::convolution::GKXC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using GNWK = ck::tensor_layout::convolution::GNWK;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
//
using NWGC = ck::tensor_layout::convolution::NWGC;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
using KXGC = ck::tensor_layout::convolution::KXGC;
using KYXGC = ck::tensor_layout::convolution::KYXGC;
using KZYXGC = ck::tensor_layout::convolution::KZYXGC;
using NWGK = ck::tensor_layout::convolution::NWGK;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
// pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
......
...@@ -25,7 +25,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn ...@@ -25,7 +25,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn
2, 2,
F32, F32,
F32, F32,
F32_TUPLE, F32_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -37,7 +37,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn ...@@ -37,7 +37,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn
2, 2,
F32, F32,
F32, F32,
F32_TUPLE, F32_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -49,7 +49,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn ...@@ -49,7 +49,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn
2, 2,
F32, F32,
F32, F32,
F32_TUPLE, F32_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -61,7 +61,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn ...@@ -61,7 +61,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
2, 2,
F32, F32,
F32, F32,
F32_TUPLE, F32_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
......
...@@ -25,7 +25,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instanc ...@@ -25,7 +25,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instanc
2, 2,
F32, F32,
F32, F32,
EMPTY_TUPLE, Empty_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -37,7 +37,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instanc ...@@ -37,7 +37,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instanc
2, 2,
F32, F32,
F32, F32,
EMPTY_TUPLE, Empty_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -49,7 +49,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instanc ...@@ -49,7 +49,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instanc
2, 2,
F32, F32,
F32, F32,
EMPTY_TUPLE, Empty_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -61,7 +61,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc ...@@ -61,7 +61,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
2, 2,
F32, F32,
F32, F32,
EMPTY_TUPLE, Empty_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward data
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1,
NWC,
KXC,
NWK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1,
NWC,
KXC,
NWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d backward data
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward data
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
KZYXC,
NDHWK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
KZYXC,
NDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
KZYXC,
NDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
KZYXC,
NDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdData<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceConvBwdData<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward weight
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d backward weight
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward weight
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdWeight<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceConvBwdWeight<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
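// Minimal usage sketch (illustrative only, not part of this header): enumerate the
// 2-D fp16 backward-weight instances and print their names. Assumes the matching
// instance library is linked and that GetTypeString() is available on the operation
// base class, as elsewhere in CK.
//
//   using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//   using DeviceOp    = ck::tensor_operation::device::DeviceConvBwdWeight<
//       2,
//       ck::tensor_layout::convolution::NHWC,
//       ck::tensor_layout::convolution::KYXC,
//       ck::tensor_layout::convolution::NHWK,
//       ck::half_t,
//       ck::half_t,
//       ck::half_t,
//       PassThrough,
//       PassThrough,
//       PassThrough>;
//
//   const auto op_ptrs = ck::tensor_operation::device::instance::
//       DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//   for(const auto& op : op_ptrs)
//       std::cout << op->GetTypeString() << '\n';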
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv2d forward
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvFwd<2,
NHWC,
KYXC,
NHWK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvFwd<2,
NHWC,
KYXC,
NHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
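// Factory specialization for forward convolution. Only the 2-D NHWC/KYXC/NHWK
// layout combination is dispatched here; for fp16 both the plain and the
// c_shuffle instance sets are collected, and unsupported combinations return an
// empty instance list.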
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceConvFwd<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceConvFwd<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
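// Illustrative sketch (assumptions: this header is included and the instance library
// is linked): a client typically fetches all instances for its problem type, then asks
// each one whether it supports the concrete problem via MakeArgumentPointer() and
// IsSupportedArgument() before invoking it. Only the enumeration step is shown here.
//
//   using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//   using DeviceOp    = ck::tensor_operation::device::DeviceConvFwd<
//       2,
//       ck::tensor_layout::convolution::NHWC,
//       ck::tensor_layout::convolution::KYXC,
//       ck::tensor_layout::convolution::NHWK,
//       int8_t,
//       int8_t,
//       int8_t,
//       PassThrough,
//       PassThrough,
//       PassThrough>;
//
//   const auto op_ptrs = ck::tensor_operation::device::instance::
//       DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//   std::cout << "found " << op_ptrs.size() << " int8 conv2d fwd instances\n";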
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -19,49 +19,53 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
 Row,
+Row_Row_Tuple,
 Row,
 F16,
 F16,
-F16_F16_TUPLE,
+F16_F16_Tuple,
 F16,
 PassThrough,
 PassThrough,
 AddAddFastGelu>>>&);
-void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
 Col,
+Row_Row_Tuple,
 Row,
 F16,
 F16,
-F16_F16_TUPLE,
+F16_F16_Tuple,
 F16,
 PassThrough,
 PassThrough,
 AddAddFastGelu>>>&);
-void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
 Row,
+Row_Row_Tuple,
 Row,
 F16,
 F16,
-F16_F16_TUPLE,
+F16_F16_Tuple,
 F16,
 PassThrough,
 PassThrough,
 AddAddFastGelu>>>&);
-void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
 std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
 Col,
+Row_Row_Tuple,
 Row,
 F16,
 F16,
-F16_F16_TUPLE,
+F16_F16_Tuple,
 F16,
 PassThrough,
 PassThrough,
@@ -70,7 +74,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
 // GEMM + Add + Add + FastGelu
 template <typename ALayout,
 typename BLayout,
-typename DELayout,
+typename D0Layout,
+typename D1Layout,
+typename ELayout,
 typename ADataType,
 typename BDataType,
 typename D0DataType,
@@ -79,7 +85,8 @@ template <typename ALayout,
 struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
 ALayout,
 BLayout,
-DELayout,
+ck::Tuple<D0Layout, D1Layout>,
+ELayout,
 ADataType,
 BDataType,
 ck::Tuple<D0DataType, D1DataType>,
@@ -90,7 +97,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 {
 using DeviceOp = DeviceGemmMultipleD<ALayout,
 BLayout,
-DELayout,
+ck::Tuple<D0Layout, D1Layout>,
+ELayout,
 ADataType,
 BDataType,
 ck::Tuple<D0DataType, D1DataType>,
@@ -108,27 +116,31 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 is_same_v<EDataType, half_t>)
 {
 if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-is_same_v<DELayout, Row>)
+is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
+is_same_v<ELayout, Row>)
 {
-add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
 op_ptrs);
 }
 else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
-is_same_v<DELayout, Row>)
+is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
+is_same_v<ELayout, Row>)
 {
-add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
 op_ptrs);
 }
 else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
-is_same_v<DELayout, Row>)
+is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
+is_same_v<ELayout, Row>)
 {
-add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
 op_ptrs);
 }
 else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
-is_same_v<DELayout, Row>)
+is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
+is_same_v<ELayout, Row>)
 {
-add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
 op_ptrs);
 }
 }
...