Commit cab21510 authored by Jing Zhang

simplify host conv forward

parent f5ae909b
@@ -15,7 +15,7 @@
 //#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
 //#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
 //#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
+//#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"

 #define USE_DYNAMIC_MODE 1
@@ -65,7 +65,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
     constexpr auto I1 = Number<1>{};

     auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
-        double v = 0;
+        float v = 0;
         for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
         {
             for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
@@ -77,33 +77,18 @@ void host_convolution_forward(const Tensor<TIn>& in,
                     if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
                        wi < in.mDesc.GetLengths()[3])
                     {
-                        if constexpr(is_same<TIn, ushort>::value)
-                        {
-                            v += ck::type_convert<float>(in(n, c, hi, wi)) *
-                                 ck::type_convert<float>(wei(k, c, y, x));
-                        }
-                        else
-                        {
-                            v += static_cast<const double>(in(n, c, hi, wi)) *
-                                 static_cast<const double>(wei(k, c, y, x));
-                        }
+                        v += ck::type_convert<float>(in(n, c, hi, wi)) *
+                             ck::type_convert<float>(wei(k, c, y, x));
                     }
                 }
             }
         }

-        if constexpr(is_same<TOut, ushort>::value)
-        {
-            out(n, k, ho, wo) = ck::type_convert<ushort>(static_cast<float>(v));
-        }
-        else
-        {
-            out(n, k, ho, wo) = v;
-        }
+        out(n, k, ho, wo) = ck::type_convert<TOut>(v);
     };

     auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
-        double v = 0;
+        float v = 0;
         for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
         {
             for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
@@ -115,28 +100,14 @@ void host_convolution_forward(const Tensor<TIn>& in,
                     if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
                        wi < in.mDesc.GetLengths()[2])
                     {
-                        if constexpr(is_same<TIn, ushort>::value)
-                        {
-                            v += ck::type_convert<float>(in(n, hi, wi, c)) *
-                                 ck::type_convert<float>(wei(k, y, x, c));
-                        }
-                        else
-                        {
-                            v += static_cast<const double>(in(n, hi, wi, c)) *
-                                 static_cast<const double>(wei(k, y, x, c));
-                        }
+                        v += ck::type_convert<float>(in(n, hi, wi, c)) *
+                             ck::type_convert<float>(wei(k, y, x, c));
                     }
                 }
             }
         }

-        if constexpr(is_same<TOut, ushort>::value)
-        {
-            out(n, ho, wo, k) = ck::type_convert<ushort>(static_cast<float>(v));
-        }
-        else
-        {
-            out(n, ho, wo, k) = v;
-        }
+        out(n, ho, wo, k) = ck::type_convert<TOut>(v);
     };

     if(layout == ConvTensorLayout::NCHW)
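
Aside on the two hunks above: both lambdas now accumulate in float for every TIn and convert once on write-out with ck::type_convert<TOut>, which is what lets the per-type if constexpr branches go away. The following self-contained sketch shows the resulting pattern; it is a simplification of mine (stride 1, no padding or dilation, dense NCHW buffers, and type_convert_sketch standing in for ck::type_convert), not the project's actual code:

#include <cstddef>
#include <vector>

// Stand-in for ck::type_convert; the real helper treats types such as
// bfloat16 specially, and a plain cast is only good enough for a sketch.
template <typename T>
T type_convert_sketch(float v)
{
    return static_cast<T>(v);
}

// Host reference convolution after the simplification: one float
// accumulator per output element, one conversion on write-out.
// Assumes stride 1 and no padding, so hi/wi never go out of bounds.
template <typename TIn, typename TWei, typename TOut>
void host_conv_fwd_nchw_sketch(const std::vector<TIn>& in,
                               const std::vector<TWei>& wei,
                               std::vector<TOut>& out,
                               int N, int C, int Hi, int Wi,
                               int K, int Y, int X, int Ho, int Wo)
{
    // dense NCHW / KCYX / NKHW linear offsets
    auto in_idx  = [&](int n, int c, int h, int w) { return static_cast<std::size_t>(((n * C + c) * Hi + h) * Wi + w); };
    auto wei_idx = [&](int k, int c, int y, int x) { return static_cast<std::size_t>(((k * C + c) * Y + y) * X + x); };
    auto out_idx = [&](int n, int k, int h, int w) { return static_cast<std::size_t>(((n * K + k) * Ho + h) * Wo + w); };

    for(int n = 0; n < N; ++n)
    for(int k = 0; k < K; ++k)
    for(int ho = 0; ho < Ho; ++ho)
    for(int wo = 0; wo < Wo; ++wo)
    {
        float v = 0; // single float accumulator replaces the old double/ushort branches

        for(int c = 0; c < C; ++c)
        for(int y = 0; y < Y; ++y)
        for(int x = 0; x < X; ++x)
        {
            v += static_cast<float>(in[in_idx(n, c, ho + y, wo + x)]) *
                 static_cast<float>(wei[wei_idx(k, c, y, x)]);
        }

        // one generic conversion instead of per-type if constexpr branches
        out[out_idx(n, k, ho, wo)] = type_convert_sketch<TOut>(v);
    }
}

// Usage: 3x3 input, 2x2 diagonal filter => 2x2 output {6, 8, 12, 14}
int main()
{
    std::vector<float> in{1, 2, 3, 4, 5, 6, 7, 8, 9};
    std::vector<float> wei{1, 0, 0, 1};
    std::vector<float> out(4, 0.f);
    host_conv_fwd_nchw_sketch(in, wei, out, 1, 1, 3, 3, 1, 2, 2, 2, 2);
}

One consequence worth noting: accumulating in float rather than double slightly loosens the host reference's precision, so comparisons against device results should use a tolerance rather than exact equality.
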
@@ -250,7 +221,7 @@ int main(int argc, char* argv[])
     constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
 #endif

-#if 0
+#if 1
     using in_data_t  = float;
     using acc_data_t = float;
     using out_data_t = float;
...
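
The final hunk flips the driver's compile-time data-type selection so that the fp32 configuration compiles. A minimal sketch of the selection pattern follows; the #else body is a placeholder of mine, since the neighboring branches are collapsed in this diff:

#if 1 // fp32 path: input, accumulator, and output all float,
      // consistent with the host reference's new float accumulator
using in_data_t  = float;
using acc_data_t = float;
using out_data_t = float;
#else
// other data-type configurations (e.g. an fp16 or bf16 block) live in the
// collapsed part of the file; enable exactly one block at a time
#endif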