Commit f05b210a authored by Jing Zhang
Browse files

finished merge

parent b3108646
...@@ -499,7 +499,7 @@ int main(int argc, char* argv[]) ...@@ -499,7 +499,7 @@ int main(int argc, char* argv[])
constexpr index_t HDilation = 1; constexpr index_t HDilation = 1;
constexpr index_t WDilation = 1; constexpr index_t WDilation = 1;
constexpr index_t Direction = 2; // 1: Forward; 0:Backward constexpr index_t Direction = 1; // 1: Forward; 0:Backward
#if 0 #if 0
constexpr index_t N = 32; constexpr index_t N = 32;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -550,10 +550,10 @@ int main(int argc, char* argv[]) ...@@ -550,10 +550,10 @@ int main(int argc, char* argv[])
#elif 1 #elif 1
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 512;
constexpr index_t HI = 56; constexpr index_t HI = 28;
constexpr index_t WI = 56; constexpr index_t WI = 28;
constexpr index_t K = 256; constexpr index_t K = 512;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -721,7 +721,7 @@ int main(int argc, char* argv[]) ...@@ -721,7 +721,7 @@ int main(int argc, char* argv[])
in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1 #elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_0{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0 #elif 0
...@@ -734,49 +734,33 @@ int main(int argc, char* argv[]) ...@@ -734,49 +734,33 @@ int main(int argc, char* argv[])
#endif #endif
} }
#if 1 if(Direction == 1)
#if 0 {
device_direct_convolution_1
#elif 0
device_convolution_direct_v2_nchw_kcyx_nkhw
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 0
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
#elif 1
device_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw
#endif
(in_nchw_desc,
in_nchw_device,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
strides,
dilations,
Number<Direction>{},
out_nkhw,
nrepeat);
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
lower_pads,
upper_pads,
nrepeat);
#endif
device_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
strides,
dilations,
Number<Direction>{},
out_nkhw_device,
nrepeat);
}
else
{
device_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw(in_nchw_desc,
in_nchw_device,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
strides,
dilations,
Number<Direction>{},
out_nkhw,
nrepeat);
}
if(do_verification) if(do_verification)
{ {
#if 0 #if 0
...@@ -800,10 +784,13 @@ int main(int argc, char* argv[]) ...@@ -800,10 +784,13 @@ int main(int argc, char* argv[])
} }
#if 0 #if 0
LogRange(std::cout << "out_nkhw: ", out_nkhw.mData, ",") << std::endl; //LogRange(std::cout << "out_nkhw: ", out_nkhw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl; //LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw_host : ", in_nchw.mData, ",") << std::endl; //LogRange(std::cout << "in_nchw_host : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw_device: ", in_nchw_device.mData, ",") << std::endl; //LogRange(std::cout << "in_nchw_device: ", in_nchw_device.mData, ",") << std::endl;
//LogRange(std::cout << "out_nkhw_host : ", out_nkhw.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
#endif #endif
} }
} }
...@@ -27,7 +27,7 @@ template <bool isForw, ...@@ -27,7 +27,7 @@ template <bool isForw,
index_t Wo, index_t Wo,
class Strides, class Strides,
class Dilations> class Dilations>
struct GetInGlobalMergeDesc; struct GetInGlobalFinalDesc;
template <class InType, template <class InType,
index_t N1, index_t N1,
...@@ -36,7 +36,7 @@ template <class InType, ...@@ -36,7 +36,7 @@ template <class InType,
index_t Wo, index_t Wo,
class Strides, class Strides,
class Dilations> class Dilations>
struct GetInGlobalMergeDesc<true, InType, N1, N2, Ho, Wo, Strides, Dilations> struct GetInGlobalFinalDesc<true, InType, N1, N2, Ho, Wo, Strides, Dilations>
{ {
__host__ __device__ constexpr auto GetDesc() __host__ __device__ constexpr auto GetDesc()
{ {
...@@ -102,7 +102,7 @@ template <class InType, ...@@ -102,7 +102,7 @@ template <class InType,
index_t Wo, index_t Wo,
class Strides, class Strides,
class Dilations> class Dilations>
struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations> struct GetInGlobalFinalDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations>
{ {
__host__ __device__ constexpr auto GetDesc() __host__ __device__ constexpr auto GetDesc()
{ {
...@@ -141,17 +141,17 @@ struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations> ...@@ -141,17 +141,17 @@ struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations>
}; };
template <bool isForw, class OutType, class Strides> template <bool isForw, class OutType, class Strides>
struct GetOutGlobalMergeDesc; struct GetOutGlobalFinalDesc;
template <class OutType, class Strides> template <class OutType, class Strides>
struct GetOutGlobalMergeDesc<true, OutType, Strides> struct GetOutGlobalFinalDesc<true, OutType, Strides>
{ {
__host__ __device__ constexpr auto GetDesc() { return OutType{}; } __host__ __device__ constexpr auto GetDesc() { return OutType{}; }
}; };
template <class OutType, class Strides> template <class OutType, class Strides>
struct GetOutGlobalMergeDesc<false, OutType, Strides> struct GetOutGlobalFinalDesc<false, OutType, Strides>
{ {
__host__ __device__ constexpr auto GetDesc() __host__ __device__ constexpr auto GetDesc()
{ {
...@@ -195,6 +195,24 @@ struct GetOutGlobalMergeDesc<false, OutType, Strides> ...@@ -195,6 +195,24 @@ struct GetOutGlobalMergeDesc<false, OutType, Strides>
} }
}; };
// Selects the weight tensor descriptor by convolution direction.
// Only the two specializations below are defined; the primary template is
// declared so that `isForw` can be dispatched on at compile time.
template <bool isForw, class WeiType>
struct GetWeiFinalDesc;

// isForw == true: hand back WeiType reordered with the new-to-old map <1, 0>.
// NOTE(review): presumably this swaps the descriptor's two dimensions —
// confirm against the definition of ReorderGivenNew2Old.
template <class WeiType>
struct GetWeiFinalDesc<true, WeiType>
{
    __host__ __device__ constexpr auto GetDesc()
    {
        using NewToOldOrder = Sequence<1, 0>;
        return WeiType{}.ReorderGivenNew2Old(NewToOldOrder{});
    }
};

// isForw == false: the descriptor is returned as-is (a value-initialized WeiType).
template <class WeiType>
struct GetWeiFinalDesc<false, WeiType>
{
    __host__ __device__ constexpr auto GetDesc()
    {
        return WeiType{};
    }
};
// define B = merge(N0, Ho, Wo) // define B = merge(N0, Ho, Wo)
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
...@@ -311,7 +329,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw ...@@ -311,7 +329,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw
.Extract(Sequence<1, 2, 3>{}); .Extract(Sequence<1, 2, 3>{});
constexpr bool fwd = Direction == 1; constexpr bool fwd = Direction == 1;
constexpr auto in_e_n1_b_n2_global_merged_desc = constexpr auto in_e_n1_b_n2_global_merged_desc =
GetInGlobalMergeDesc<fwd, GetInGlobalFinalDesc<fwd,
decltype(in_n_c_h_w_global_desc), decltype(in_n_c_h_w_global_desc),
N1, N1,
N2, N2,
...@@ -352,7 +370,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw ...@@ -352,7 +370,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc = wei_k_c_global_desc; constexpr auto wei_e_k_global_desc =
GetWeiFinalDesc<fwd, decltype(wei_k_c_global_desc)>{}.GetDesc();
// tensor descriptor in LDS, dst of blockwise copy // tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
...@@ -576,7 +595,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw ...@@ -576,7 +595,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw
.Fold(I0, Number<N1>{}, Number<N2>{}); .Fold(I0, Number<N1>{}, Number<N2>{});
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc = constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc =
GetOutGlobalMergeDesc<fwd, GetOutGlobalFinalDesc<fwd,
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
Strides>{} Strides>{}
.GetDesc(); .GetDesc();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment