Commit 95ce725e authored by ltqin
Browse files

Merge branch 'add_host_conv_bwd_wrw' into backward_weight_v4r4r2_xdlops

parents c27a57d4 fca3500e
#pragma once #pragma once
#include "host_tensor.hpp" #include "host_tensor.hpp"
template <typename TIn, template <typename TOut,
typename TOut, typename TIn,
typename TWei, typename TWei,
typename ConvStrides, typename ConvStrides,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void host_direct_convolution_backward_weights( void host_direct_convolution_backward_weights(
const Tensor<TIn>& in,
const Tensor<TOut>& out, const Tensor<TOut>& out,
const Tensor<TIn>& in,
Tensor<TWei>& wei, Tensor<TWei>& wei,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
...@@ -29,7 +29,7 @@ void host_direct_convolution_backward_weights( ...@@ -29,7 +29,7 @@ void host_direct_convolution_backward_weights(
for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho) for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho)
{ {
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int wo = 0; wo < wei.mDesc.GetLengths()[3]; ++wo) for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo)
{ {
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
...@@ -51,7 +51,7 @@ void host_direct_convolution_backward_weights( ...@@ -51,7 +51,7 @@ void host_direct_convolution_backward_weights(
for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho) for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
{ {
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int wo = 0; wo < wei.mDesc.GetLengths()[2]; ++wo) for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo)
{ {
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment