Commit 9465cf3e authored by Jing Zhang's avatar Jing Zhang
Browse files

remove bwd, seperate wrw

parent 16d78f5d
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <initializer_list> #include <initializer_list>
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> //#include <half.hpp>
#include "config.hpp" #include "config.hpp"
#include "print.hpp" #include "print.hpp"
#include "device.hpp" #include "device.hpp"
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv.hpp" #include "host_conv.hpp"
#include "host_conv_wrw.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
......
...@@ -12,13 +12,6 @@ enum ConvTensorLayout ...@@ -12,13 +12,6 @@ enum ConvTensorLayout
NHWCc NHWCc
}; };
// Direction of the convolution computed by the host-side reference routines.
enum ConvDirection
{
Forward,         // forward pass: out = conv(in, wei)
BackwardData,    // data gradient: in = conv_bwd_data(out, wei)
BackwardWeights  // weight gradient: wei = conv_bwd_weights(in, out)
};
template <typename... InDesc, template <typename... InDesc,
typename... WeiDesc, typename... WeiDesc,
typename ConvStrides, typename ConvStrides,
......
...@@ -8,233 +8,83 @@ template <typename TIn, ...@@ -8,233 +8,83 @@ template <typename TIn,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void host_direct_convolution(Tensor<TIn>& in, void host_direct_convolution(const Tensor<TIn>& in,
Tensor<TWei>& wei, const Tensor<TWei>& wei,
Tensor<TOut>& out, Tensor<TOut>& out,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads&, const InRightPads&,
const ConvTensorLayout layout = ConvTensorLayout::NCHW, const ConvTensorLayout layout = ConvTensorLayout::NCHW)
const ConvDirection dir = ConvDirection::Forward)
{ {
using namespace ck; using namespace ck;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
if(dir == ConvDirection::Forward) auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
{ double v = 0;
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
double v = 0;
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
{
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
{
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
{
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
{
v += static_cast<const double>(in(n, c, hi, wi)) *
static_cast<const double>(wei(k, c, y, x));
}
}
}
}
out(n, k, ho, wo) = v;
};
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
double v = 0;
for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
{
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
{
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
{
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
wi < in.mDesc.GetLengths()[2])
{
v += static_cast<const double>(in(n, hi, wi, c)) *
static_cast<const double>(wei(k, y, x, c));
}
}
}
}
out(n, ho, wo, k) = v;
};
if(layout == ConvTensorLayout::NCHW)
{
make_ParallelTensorFunctor(f_nchw,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
}
else if(layout == ConvTensorLayout::NHWC)
{ {
make_ParallelTensorFunctor(f_nhwc, for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("wrong! not supported layout");
}
}
else if(dir == ConvDirection::BackwardData)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
double v = 0;
for(int k = 0; k < wei.mDesc.GetLengths()[0]; ++k)
{ {
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
{ {
int ho = (hi - y * conv_dilations[I0] + in_left_pads[I0]) / conv_strides[I0]; int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
{ {
int wo = v += static_cast<const double>(in(n, c, hi, wi)) *
(wi - x * conv_dilations[I1] + in_left_pads[I1]) / conv_strides[I1]; static_cast<const double>(wei(k, c, y, x));
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
{
v += static_cast<const double>(out(n, k, ho, wo)) *
static_cast<const double>(wei(k, c, y, x));
}
} }
} }
} }
in(n, c, hi, wi) = v; }
}; out(n, k, ho, wo) = v;
};
auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
double v = 0; double v = 0;
for(int k = 0; k < wei.mDesc.GetLengths()[0]; ++k) for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
{
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
{ {
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y) int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
{ {
int ho = (hi - y * conv_dilations[I0] + in_left_pads[I0]) / conv_strides[I0]; int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x) if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
wi < in.mDesc.GetLengths()[2])
{ {
int wo = v += static_cast<const double>(in(n, hi, wi, c)) *
(wi - x * conv_dilations[I1] + in_left_pads[I1]) / conv_strides[I1]; static_cast<const double>(wei(k, y, x, c));
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
wi < in.mDesc.GetLengths()[2])
{
v += static_cast<const double>(out(n, ho, wo, k)) *
static_cast<const double>(wei(k, y, x, c));
}
} }
} }
} }
in(n, hi, wi, c) = v;
};
if(layout == ConvTensorLayout::NCHW)
{
make_ParallelTensorFunctor(f_nchw,
in.mDesc.GetLengths()[0],
in.mDesc.GetLengths()[1],
in.mDesc.GetLengths()[2],
in.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
}
else if(layout == ConvTensorLayout::NHWC)
{
make_ParallelTensorFunctor(f_nhwc,
in.mDesc.GetLengths()[0],
in.mDesc.GetLengths()[1],
in.mDesc.GetLengths()[2],
in.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("wrong! not supported layout");
} }
out(n, ho, wo, k) = v;
};
if(layout == ConvTensorLayout::NCHW)
{
make_ParallelTensorFunctor(f_nchw,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
} }
else if(dir == ConvDirection::BackwardWeights) else if(layout == ConvTensorLayout::NHWC)
{ {
auto f_kcyx = [&](auto k, auto c, auto y, auto x) { make_ParallelTensorFunctor(f_nhwc,
double v = 0; out.mDesc.GetLengths()[0],
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) out.mDesc.GetLengths()[1],
{ out.mDesc.GetLengths()[2],
for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho) out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
{
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int wo = 0; wo < wei.mDesc.GetLengths()[3]; ++wo)
{
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
{
v += static_cast<const double>(in(n, c, hi, wi)) *
static_cast<const double>(out(n, k, ho, wo));
}
}
}
}
wei(k, c, y, x) = v;
};
auto f_kyxc = [&](auto k, auto y, auto x, auto c) {
double v = 0;
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
{
for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
{
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int wo = 0; wo < wei.mDesc.GetLengths()[2]; ++wo)
{
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
wi < in.mDesc.GetLengths()[2])
{
v += static_cast<const double>(in(n, hi, wi, c)) *
static_cast<const double>(out(n, ho, wo, k));
}
}
}
}
wei(k, y, x, c) = v;
};
if(layout == ConvTensorLayout::NCHW)
{
make_ParallelTensorFunctor(f_kcyx,
wei.mDesc.GetLengths()[0],
wei.mDesc.GetLengths()[1],
wei.mDesc.GetLengths()[2],
wei.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
}
else if(layout == ConvTensorLayout::NHWC)
{
make_ParallelTensorFunctor(f_kyxc,
wei.mDesc.GetLengths()[0],
wei.mDesc.GetLengths()[1],
wei.mDesc.GetLengths()[2],
wei.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("wrong! not supported layout");
}
} }
else else
{ {
throw std::runtime_error("wrong! not supported direction"); throw std::runtime_error("wrong! not supported layout");
} }
} }
......
#pragma once
#include "host_tensor.hpp"
template <typename TIn,
          typename TOut,
          typename TWei,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void host_direct_convolution_backward_weights(
    const Tensor<TIn>& in,
    const Tensor<TOut>& out,
    Tensor<TWei>& wei,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads&,
    const ConvTensorLayout layout = ConvTensorLayout::NCHW)
{
    // Reference (CPU) implementation of convolution backward-weights (wrw):
    // each weight element accumulates in * out over (n, ho, wo).
    //
    //   in  : activations            — NCHW [N,C,Hi,Wi] or NHWC [N,Hi,Wi,C]
    //   out : output gradient        — NCHW [N,K,Ho,Wo] or NHWC [N,Ho,Wo,K]
    //   wei : weight gradient result — KCYX [K,C,Y,X]   or KYXC [K,Y,X,C]
    //
    // Right padding is implied by the tensor lengths, so InRightPads is unused.
    // Throws std::runtime_error for unsupported layouts.
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // One invocation per weight element (k, c, y, x) — KCYX layout.
    auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
        double v = 0;
        for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
        {
            for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho)
            {
                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
                // BUGFIX: iterate over the output width out[3], not the filter
                // width wei[3]; the two only coincide when Wo == X, otherwise
                // part of the gradient is dropped (or out() is read OOB).
                for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo)
                {
                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
                    // Skip taps that land in the (implicit zero) padding region.
                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[3])
                    {
                        // Accumulate in double to reduce rounding error for
                        // low-precision TIn/TOut.
                        v += static_cast<double>(in(n, c, hi, wi)) *
                             static_cast<double>(out(n, k, ho, wo));
                    }
                }
            }
        }
        wei(k, c, y, x) = v;
    };

    // One invocation per weight element (k, y, x, c) — KYXC layout.
    auto f_kyxc = [&](auto k, auto y, auto x, auto c) {
        double v = 0;
        for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
        {
            for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
            {
                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
                // BUGFIX: bound is the output width out[2] (NHWC), not the
                // filter width wei[2].
                for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo)
                {
                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
                    if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[2])
                    {
                        v += static_cast<double>(in(n, hi, wi, c)) *
                             static_cast<double>(out(n, ho, wo, k));
                    }
                }
            }
        }
        wei(k, y, x, c) = v;
    };

    if(layout == ConvTensorLayout::NCHW)
    {
        make_ParallelTensorFunctor(f_kcyx,
                                   wei.mDesc.GetLengths()[0],
                                   wei.mDesc.GetLengths()[1],
                                   wei.mDesc.GetLengths()[2],
                                   wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else if(layout == ConvTensorLayout::NHWC)
    {
        make_ParallelTensorFunctor(f_kyxc,
                                   wei.mDesc.GetLengths()[0],
                                   wei.mDesc.GetLengths()[1],
                                   wei.mDesc.GetLengths()[2],
                                   wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error("wrong! not supported layout");
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment