Commit a47845f2 authored by Jing Zhang's avatar Jing Zhang
Browse files

fix

parent de6f254d
...@@ -88,7 +88,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -88,7 +88,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
#endif #endif
constexpr index_t GridSize = constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock) / (Strides{}.Get(I1) * Strides{}.Get(I0));
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
......
...@@ -16,6 +16,17 @@ ...@@ -16,6 +16,17 @@
#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp" #include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp" #include "device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
struct GeneratorTensor_0
{
template <class... Is>
double operator()(Is... is)
{
return 0;
}
};
struct GeneratorTensor_1 struct GeneratorTensor_1
{ {
template <class... Is> template <class... Is>
...@@ -196,12 +207,14 @@ void host_direct_convolution_back(Tensor<TOut>& in_nchw, ...@@ -196,12 +207,14 @@ void host_direct_convolution_back(Tensor<TOut>& in_nchw,
{ {
for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y) for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
{ {
int ho = (hi - y * dilation_h + h_pad_low) / stride_h; int ho_ = (hi - y * dilation_h + h_pad_low);
int ho = ho_ / stride_h;
for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x) for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
{ {
int wo = (wi - x * dilation_w + w_pad_low) / stride_w; int wo_ = (wi - x * dilation_w + w_pad_low);
if(ho >= 0 && hi < out_nkhw.mDesc.GetLengths()[2] && wo >= 0 && int wo = wo_ / stride_w;
wo < out_nkhw.mDesc.GetLengths()[3] && ho % stride_h == 0 && wo % stride_w == 0) if(ho >= 0 && ho < out_nkhw.mDesc.GetLengths()[2] && wo >= 0 &&
wo < out_nkhw.mDesc.GetLengths()[3] && ho_ % stride_h == 0 && wo_ % stride_w == 0)
{ {
v += double(out_nkhw(n, k, ho, wo)) * double(wei_kcyx(k, c, y, x)); v += double(out_nkhw(n, k, ho, wo)) * double(wei_kcyx(k, c, y, x));
} }
...@@ -489,12 +502,12 @@ int main(int argc, char* argv[]) ...@@ -489,12 +502,12 @@ int main(int argc, char* argv[])
constexpr index_t WDilation = 1; constexpr index_t WDilation = 1;
constexpr index_t Direction = 2; //1: Forward; 2:Backward constexpr index_t Direction = 2; //1: Forward; 2:Backward
#if 1 #if 0
constexpr index_t N = 8; constexpr index_t N = 8;
constexpr index_t C = 128; constexpr index_t C = 128;
constexpr index_t HI = 16; constexpr index_t HI = 2;
constexpr index_t WI = 16; constexpr index_t WI = 32;
constexpr index_t K = 128; constexpr index_t K = 16;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -706,9 +719,9 @@ int main(int argc, char* argv[]) ...@@ -706,9 +719,9 @@ int main(int argc, char* argv[])
in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1 #elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); out_nkhw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
...@@ -785,10 +798,10 @@ int main(int argc, char* argv[]) ...@@ -785,10 +798,10 @@ int main(int argc, char* argv[])
} }
#if 0 #if 0
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; //LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl; LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; LogRange(std::cout << "in_nchw_host : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; LogRange(std::cout << "in_nchw_device: ", in_nchw_device.mData, ",") << std::endl;
#endif #endif
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment