Commit 0e0dcb38 authored by Chao Liu's avatar Chao Liu
Browse files

debugged

parent a47845f2
......@@ -49,7 +49,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t N1 = 2;
constexpr index_t N2 = 4;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
constexpr index_t B = (N * (Ho/Strides::Get(I0)) * (Wo/Strides::Get(I1))) / (N1 * N2);
#if 1
constexpr index_t BlockSize = 256;
......@@ -88,7 +88,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
#endif
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock) / (Strides{}.Get(I1) * Strides{}.Get(I0));
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
......
......@@ -53,11 +53,11 @@ struct GeneratorTensor_3
template <class... Is>
double operator()(Is... is)
{
std::array<index_t, sizeof...(Is)> dims = {{static_cast<index_t>(is)...}};
std::array<index_t, sizeof...(Is)> multi_id = {{static_cast<index_t>(is)...}};
auto f_acc = [](auto a, auto b) { return 100 * a + b; };
auto f_acc = [](auto a, auto b) { return 10 * a + b; };
return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc);
return std::accumulate(multi_id.begin(), multi_id.end(), index_t(0), f_acc);
}
};
......@@ -66,7 +66,7 @@ struct GeneratorTensor_Checkboard
template <class... Ts>
double operator()(Ts... Xs) const
{
std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
std::array<index_t, sizeof...(Ts)> dims = {{static_cast<index_t>(Xs)...}};
return std::accumulate(dims.begin(),
dims.end(),
true,
......@@ -503,10 +503,10 @@ int main(int argc, char* argv[])
constexpr index_t Direction = 2; //1: Forward; 2:Backward
#if 0
constexpr index_t N = 8;
constexpr index_t N = 32;
constexpr index_t C = 128;
constexpr index_t HI = 2;
constexpr index_t WI = 32;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;
......@@ -552,10 +552,10 @@ int main(int argc, char* argv[])
#elif 1
// 1x1 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t C = 128;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
......@@ -710,18 +710,22 @@ int main(int argc, char* argv[])
if(do_verification)
{
#if 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
in_nchw.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
in_nchw.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
//out_nkhw.GenerateTensorValue(GeneratorTensor_Checkboard{}, num_thread);
//out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_4{}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
......@@ -798,7 +802,7 @@ int main(int argc, char* argv[])
}
#if 0
//LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw: ", out_nkhw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw_host : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw_device: ", in_nchw_device.mData, ",") << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment