"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "580fe8951cd0d89a43cae16c24ec78c604b8dcb1"
Commit a47845f2 authored by Jing Zhang

fix

parent de6f254d
@@ -88,7 +88,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
 #endif
     constexpr index_t GridSize =
-        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
+        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock) / (Strides{}.Get(I1) * Strides{}.Get(I0));
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
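Note (not part of the commit): the new GridSize expression divides the block count by the product of the two convolution strides taken from Strides{}. Below is a minimal standalone sketch of that arithmetic only, using made-up values for B, K, the per-block tile sizes, and the strides; it is not the kernel launch code itself.

// Sketch of the GridSize arithmetic above with hypothetical values.
#include <cstdio>

int main()
{
    const unsigned B = 4096, K = 128;                // hypothetical GEMM dimensions
    const unsigned BPerBlock = 128, KPerBlock = 128; // hypothetical tile sizes
    const unsigned stride_h = 2, stride_w = 2;       // hypothetical conv strides

    const unsigned grid_before =
        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
    const unsigned grid_after = grid_before / (stride_h * stride_w);

    printf("GridSize before: %u, after dividing by the stride product: %u\n",
           grid_before, grid_after);
    return 0;
}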
@@ -16,6 +16,17 @@
 #include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
 #include "device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
+
+struct GeneratorTensor_0
+{
+    template <class... Is>
+    double operator()(Is... is)
+    {
+        return 0;
+    }
+};
+
 struct GeneratorTensor_1
 {
     template <class... Is>
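Note (not part of the commit): GeneratorTensor_0 follows the generator-functor pattern used by the driver, where a templated operator() accepts any pack of index arguments and returns the value for that position. The sketch below shows the pattern in isolation. GeneratorTensor_1's body is cut off in the hunk above and is assumed here to return 1, and SimpleFill is a hypothetical stand-in for the repo's GenerateTensorValue, not the actual Tensor API.

// Standalone sketch of constant generator functors and how such a functor
// can be invoked per index to fill a buffer (SimpleFill is hypothetical).
#include <cstdio>
#include <vector>

struct GeneratorTensor_0
{
    template <class... Is>
    double operator()(Is...)
    {
        return 0;
    }
};

struct GeneratorTensor_1 // assumed to return 1; body not shown in the diff
{
    template <class... Is>
    double operator()(Is...)
    {
        return 1;
    }
};

// hypothetical helper: fill a flat n x c buffer by calling the generator per index
template <class Gen>
std::vector<double> SimpleFill(Gen gen, int n, int c)
{
    std::vector<double> data(n * c);
    for(int i = 0; i < n; ++i)
        for(int j = 0; j < c; ++j)
            data[i * c + j] = gen(i, j);
    return data;
}

int main()
{
    auto zeros = SimpleFill(GeneratorTensor_0{}, 2, 3);
    auto ones  = SimpleFill(GeneratorTensor_1{}, 2, 3);
    printf("zeros[0]=%f  ones[0]=%f\n", zeros[0], ones[0]);
    return 0;
}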
@@ -196,12 +207,14 @@ void host_direct_convolution_back(Tensor<TOut>& in_nchw,
         {
             for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
             {
-                int ho = (hi - y * dilation_h + h_pad_low) / stride_h;
+                int ho_ = (hi - y * dilation_h + h_pad_low);
+                int ho  = ho_ / stride_h;
                 for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
                 {
-                    int wo = (wi - x * dilation_w + w_pad_low) / stride_w;
-                    if(ho >= 0 && hi < out_nkhw.mDesc.GetLengths()[2] && wo >= 0 &&
-                       wo < out_nkhw.mDesc.GetLengths()[3] && ho % stride_h == 0 && wo % stride_w == 0)
+                    int wo_ = (wi - x * dilation_w + w_pad_low);
+                    int wo  = wo_ / stride_w;
+                    if(ho >= 0 && ho < out_nkhw.mDesc.GetLengths()[2] && wo >= 0 &&
+                       wo < out_nkhw.mDesc.GetLengths()[3] && ho_ % stride_h == 0 && wo_ % stride_w == 0)
                     {
                         v += double(out_nkhw(n, k, ho, wo)) * double(wei_kcyx(k, c, y, x));
                     }
@@ -489,12 +502,12 @@ int main(int argc, char* argv[])
     constexpr index_t WDilation = 1;
     constexpr index_t Direction = 2; //1: Forward; 2:Backward
-#if 1
+#if 0
     constexpr index_t N = 8;
     constexpr index_t C = 128;
-    constexpr index_t HI = 16;
-    constexpr index_t WI = 16;
-    constexpr index_t K = 128;
+    constexpr index_t HI = 2;
+    constexpr index_t WI = 32;
+    constexpr index_t K = 16;
     constexpr index_t Y = 1;
     constexpr index_t X = 1;
@@ -706,9 +719,9 @@ int main(int argc, char* argv[])
     in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 1
-    in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-    out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-    wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+    in_nchw.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
+    out_nkhw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+    wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 0
     in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
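Note (not part of the commit): switching the test data from random GeneratorTensor_2 values to constants makes the backward result easy to check by hand. A back-of-envelope sketch, assuming a 1x1 kernel, unit stride and dilation, and no padding (hypothetical shapes, not the driver's): the accumulation v += out_nkhw(n, k, ho, wo) * wei_kcyx(k, c, y, x) collapses to a sum of K ones, so every in_nchw element should come out equal to K.

// Expected value of one in_nchw element under the constant generators,
// assuming a 1x1 kernel, unit stride/dilation, and no padding.
#include <cstdio>

int main()
{
    const int K = 16; // hypothetical number of output channels
    double v = 0;
    for(int k = 0; k < K; ++k)
        v += 1.0 * 1.0; // out == 1 (GeneratorTensor_1), wei == 1 (GeneratorTensor_1)
    printf("expected in_nchw element value: %f (== K)\n", v);
    return 0;
}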
@@ -785,10 +798,10 @@ int main(int argc, char* argv[])
     }
 #if 0
-    LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
+    //LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
     LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
-    LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
-    LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
+    LogRange(std::cout << "in_nchw_host : ", in_nchw.mData, ",") << std::endl;
+    LogRange(std::cout << "in_nchw_device: ", in_nchw_device.mData, ",") << std::endl;
 #endif
 }
 }