"test/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "26941fa3777618dfc659e638e524b65f22dd32a6"
Commit e38ee30a authored by Chao Liu
Browse files

tweaking

parent 91e0de2e
...@@ -126,7 +126,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw ...@@ -126,7 +126,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM: atomic add // GEMM
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
: InMemoryDataOperation::atomic_add;
constexpr auto gridwise_gemm = constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize, GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize, BlockSize,
...@@ -135,7 +139,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw ...@@ -135,7 +139,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
decltype(wei_k_e_global_desc), decltype(wei_k_e_global_desc),
decltype(out_k_b_global_desc), decltype(out_k_b_global_desc),
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
InMemoryDataOperation::atomic_add, in_memory_op,
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
......
...@@ -352,8 +352,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl ...@@ -352,8 +352,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
} }
} }
// input: register to global memory, atomic add
{ {
#if 1 // debug
// input: register to global memory, atomic add
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
: InMemoryDataOperation::atomic_add;
#else
constexpr auto in_memory_op = InMemoryDataOperation::atomic_add;
#endif
constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster; constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t E0 = E / E1; constexpr index_t E0 = E / E1;
...@@ -426,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl ...@@ -426,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
InThreadCopyDstDataPerWrite_B, InThreadCopyDstDataPerWrite_B,
AddressSpace::vgpr, AddressSpace::vgpr,
AddressSpace::global, AddressSpace::global,
InMemoryDataOperation::atomic_add>({0, 0, 0, 0, 0, 0}, in_memory_op>({0, 0, 0, 0, 0, 0},
{e_thread_data_on_global / E1, {e_thread_data_on_global / E1,
e_thread_data_on_global % E1, e_thread_data_on_global % E1,
0, 0,
b_thread_data_on_global / B1, b_thread_data_on_global / B1,
b_thread_data_on_global % B1, b_thread_data_on_global % B1,
0}) 0})
.Run(p_in_thread, p_in_global); .Run(p_in_thread, p_in_global);
} }
} }
......
...@@ -23,10 +23,10 @@ int main(int argc, char* argv[]) ...@@ -23,10 +23,10 @@ int main(int argc, char* argv[])
{ {
using namespace launcher; using namespace launcher;
#if 1 #if 0
// 3x3 filter, 2x2 stride, 35x35 input // 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 35; constexpr index_t HI = 35;
constexpr index_t WI = 35; constexpr index_t WI = 35;
constexpr index_t K = 1024; constexpr index_t K = 1024;
...@@ -59,7 +59,7 @@ int main(int argc, char* argv[]) ...@@ -59,7 +59,7 @@ int main(int argc, char* argv[])
constexpr index_t C = 1024; constexpr index_t C = 1024;
constexpr index_t HI = 28; constexpr index_t HI = 28;
constexpr index_t WI = 28; constexpr index_t WI = 28;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -83,13 +83,13 @@ int main(int argc, char* argv[]) ...@@ -83,13 +83,13 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 1x1 filter, 7x7 image // 1x1 filter, 7x7 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 1024;
constexpr index_t HI = 7; constexpr index_t HI = 7;
constexpr index_t WI = 7; constexpr index_t WI = 7;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -158,13 +158,13 @@ int main(int argc, char* argv[]) ...@@ -158,13 +158,13 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>; using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>;
#elif 1 #elif 0
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 7; constexpr index_t X = 7;
...@@ -246,7 +246,7 @@ int main(int argc, char* argv[]) ...@@ -246,7 +246,7 @@ int main(int argc, char* argv[])
#endif #endif
} }
#if 1 #if 0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 1 #elif 1
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
...@@ -257,17 +257,17 @@ int main(int argc, char* argv[]) ...@@ -257,17 +257,17 @@ int main(int argc, char* argv[])
#else #else
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#endif #endif
(in_nchw_desc, (in_nchw_desc,
in_nchw_device, in_nchw_device,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
out_nkhw_desc, out_nkhw_desc,
out_nkhw, out_nkhw,
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
if(do_verification) if(do_verification)
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment