"vscode:/vscode.git/clone" did not exist on "3848606c7ed98c585b7a41397f99e1a873b17f61"
Commit cab21510 authored by Jing Zhang

simplify host conv forward

parent f5ae909b
@@ -15,7 +15,7 @@
 //#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
 //#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
 //#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
+//#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
 
 #define USE_DYNAMIC_MODE 1
@@ -65,7 +65,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
     constexpr auto I1 = Number<1>{};
     auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
-        double v = 0;
+        float v = 0;
 
         for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
         {
             for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
@@ -76,34 +76,19 @@ void host_convolution_forward(const Tensor<TIn>& in,
                     int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
 
                     if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
                        wi < in.mDesc.GetLengths()[3])
-                    {
-                        if constexpr(is_same<TIn, ushort>::value)
-                        {
-                            v += ck::type_convert<float>(in(n, c, hi, wi)) *
-                                 ck::type_convert<float>(wei(k, c, y, x));
-                        }
-                        else
-                        {
-                            v += static_cast<const double>(in(n, c, hi, wi)) *
-                                 static_cast<const double>(wei(k, c, y, x));
-                        }
-                    }
+                    {
+                        v += ck::type_convert<float>(in(n, c, hi, wi)) *
+                             ck::type_convert<float>(wei(k, c, y, x));
+                    }
                 }
             }
         }
 
-        if constexpr(is_same<TOut, ushort>::value)
-        {
-            out(n, k, ho, wo) = ck::type_convert<ushort>(static_cast<float>(v));
-        }
-        else
-        {
-            out(n, k, ho, wo) = v;
-        }
+        out(n, k, ho, wo) = ck::type_convert<TOut>(v);
     };
 
     auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
-        double v = 0;
+        float v = 0;
 
         for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
         {
             for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
@@ -114,29 +99,15 @@ void host_convolution_forward(const Tensor<TIn>& in,
                     int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
 
                     if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
                        wi < in.mDesc.GetLengths()[2])
-                    {
-                        if constexpr(is_same<TIn, ushort>::value)
-                        {
-                            v += ck::type_convert<float>(in(n, hi, wi, c)) *
-                                 ck::type_convert<float>(wei(k, y, x, c));
-                        }
-                        else
-                        {
-                            v += static_cast<const double>(in(n, hi, wi, c)) *
-                                 static_cast<const double>(wei(k, y, x, c));
-                        }
-                    }
+                    {
+                        v += ck::type_convert<float>(in(n, hi, wi, c)) *
+                             ck::type_convert<float>(wei(k, y, x, c));
+                    }
                 }
             }
         }
 
-        if constexpr(is_same<TOut, ushort>::value)
-        {
-            out(n, ho, wo, k) = ck::type_convert<ushort>(static_cast<float>(v));
-        }
-        else
-        {
-            out(n, ho, wo, k) = v;
-        }
+        out(n, ho, wo, k) = ck::type_convert<TOut>(v);
     };
 
     if(layout == ConvTensorLayout::NCHW)
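
A side note on the ushort special-casing removed above: in this vintage of composable_kernel, ushort appears to be the storage type for bfloat16, which is why the old code routed it through ck::type_convert while everything else went through static_cast. A minimal sketch of what such a bit-level conversion looks like; these truncating helpers are hypothetical and are not CK's actual implementation:

#include <cstdint>
#include <cstring>
#include <iostream>

using bhalf_t = std::uint16_t; // bfloat16 stored as a raw 16-bit pattern

// float -> bfloat16: keep the top 16 bits (sign, 8-bit exponent, 7 mantissa
// bits). Truncation shown for brevity; production code rounds to nearest even.
bhalf_t float_to_bf16(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<bhalf_t>(bits >> 16);
}

// bfloat16 -> float: shift the pattern back into the high half of a float.
float bf16_to_float(bhalf_t h)
{
    std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main()
{
    float x = 3.14159f;
    std::cout << bf16_to_float(float_to_bf16(x)) << '\n'; // ~3.14 (7-bit mantissa)
}

Because a conversion like this only needs to happen on loads and stores, the reference loop itself can stay type-agnostic, which is what the simplification exploits.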
@@ -250,7 +221,7 @@ int main(int argc, char* argv[])
     constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
 #endif
 
-#if 0
+#if 1
     using in_data_t  = float;
     using acc_data_t = float;
     using out_data_t = float;
...
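
For reference, the whole simplification boils down to one pattern: convert each operand to the accumulator type on read, accumulate in float, and convert once to TOut on the store. A minimal, self-contained sketch of that pattern in a 1-D reference convolution; toy_convert is a hypothetical stand-in for ck::type_convert, not CK's real helper:

#include <iostream>

// Hypothetical stand-in for ck::type_convert: a plain static_cast here; the
// real helper also handles bit-level conversions such as bfloat16 <-> float.
template <typename TDst, typename TSrc>
TDst toy_convert(TSrc v)
{
    return static_cast<TDst>(v);
}

// 1-D valid convolution: accumulate in float, convert once on the store.
template <typename TIn, typename TWei, typename TOut>
void conv1d_ref(const TIn* in, const TWei* wei, TOut* out, int n_in, int n_wei)
{
    for(int o = 0; o < n_in - n_wei + 1; ++o)
    {
        float v = 0;

        for(int x = 0; x < n_wei; ++x)
            v += toy_convert<float>(in[o + x]) * toy_convert<float>(wei[x]);

        out[o] = toy_convert<TOut>(v);
    }
}

int main()
{
    const float in[5]  = {1, 2, 3, 4, 5};
    const float wei[2] = {1, -1};
    float out[4];

    conv1d_ref(in, wei, out, 5, 2);

    for(float v : out)
        std::cout << v << ' '; // prints: -1 -1 -1 -1
    std::cout << '\n';
}

With this shape, supporting a new element type only requires a conversion specialization; the nested loops never branch on TIn or TOut, which is exactly what lets the commit delete both if constexpr blocks.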