Commit 9750de73 authored by Chao Liu
Browse files

adding bwd data v3r1

parent ef2664fb
...@@ -205,14 +205,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -205,14 +205,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
#if 1 // debug // get a series of GEMMs
// GEMMs auto f_get_gemm = [&](auto ytilda_, auto xtilda_) {
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
#else
static_for<1, 2, 1>{}([&](auto ytilda_) {
static_for<1, 2, 1>{}([&](auto xtilda_) {
#endif
constexpr index_t ytilda = decltype(ytilda_){}; constexpr index_t ytilda = decltype(ytilda_){};
constexpr index_t xtilda = decltype(xtilda_){}; constexpr index_t xtilda = decltype(xtilda_){};
...@@ -231,10 +225,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -231,10 +225,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
Trim<Sequence<Ytilda, Xtilda>, Trim<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>, Sequence<ytilda, xtilda>,
Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}), Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
make_tuple( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc, wei_k_c_ydotnonzero_1_xdotnonzero_1_global_desc,
...@@ -340,7 +332,28 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -340,7 +332,28 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
3, 3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{}; GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); return gridwise_gemm;
};
// Shared-memory sizing pass: visit every (ytilda, xtilda) GEMM produced by
// f_get_gemm and keep the largest GetSharedMemorySize() value, so that a single
// buffer can be handed to each GEMM's Run afterwards.
// NOTE(review): shared_mem_size is a plain (non-constexpr) index_t that is
// assigned inside the lambdas, yet it is used below as the bound of a
// statically declared __shared__ array. A static __shared__ array normally
// needs a compile-time constant extent -- confirm this compiles, or make the
// maximum a constexpr value.
index_t shared_mem_size = 0;
static_for<0, Ytilda, 1>{}([&](auto ytilda) {
static_for<0, Xtilda, 1>{}([&](auto xtilda) {
auto gemm = f_get_gemm(ytilda, xtilda);
// GetSharedMemorySize() is multiplied by sizeof(Float) in its return
// expression, i.e. it is a byte count.
shared_mem_size = math::max(shared_mem_size, gemm.GetSharedMemorySize());
});
});
// Convert the byte count into a number of Float elements for the array bound.
__shared__ Float p_shared_float[shared_mem_size / sizeof(Float)];
// GEMMs
static_for<0, Ytilda, 1>{}([&](auto ytilda) {
static_for<0, Xtilda, 1>{}([&](auto xtilda) {
auto gemm = f_get_gemm(ytilda, xtilda);
gemm.Run(p_wei_global, p_in_global, p_out_global, p_shared_float);
}); });
}); });
} }
......
...@@ -50,9 +50,37 @@ template <index_t GridSize, ...@@ -50,9 +50,37 @@ template <index_t GridSize,
index_t CThreadCopyDstDataPerWrite> index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v1 struct GridwiseGemmTransposedANormalBNormalC_v1
{ {
// Size, in bytes, of the LDS buffer that Run() expects through its p_shared
// argument. The value covers two copies of the aligned A block plus two copies
// of the aligned B block (Run() names the corresponding pointers
// p_a_block_double / p_b_block_double).
__host__ __device__ static constexpr index_t GetSharedMemorySize()
{
    // One alignment that is a multiple of both block-copy write widths and
    // both thread-GEMM read widths.
    constexpr index_t lds_alignment = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                BBlockCopyDstDataPerWrite_N,
                                                ThreadGemmDataPerReadM,
                                                ThreadGemmDataPerReadN);

    // A block in LDS (destination of the blockwise copy): KPerBlock x MPerBlock,
    // laid out with the alignment above.
    constexpr auto a_lds_desc = make_native_tensor_descriptor_aligned(
        Sequence<KPerBlock, MPerBlock>{}, Number<lds_alignment>{});

    // B block in LDS (destination of the blockwise copy): KPerBlock x NPerBlock,
    // laid out with the alignment above.
    constexpr auto b_lds_desc = make_native_tensor_descriptor_aligned(
        Sequence<KPerBlock, NPerBlock>{}, Number<lds_alignment>{});

    // Element counts of each block, rounded up to a multiple of the alignment.
    constexpr index_t a_lds_elements =
        math::integer_least_multiple(a_lds_desc.GetElementSpace(), lds_alignment);
    constexpr index_t b_lds_elements =
        math::integer_least_multiple(b_lds_desc.GetElementSpace(), lds_alignment);

    // Two copies of each block, converted from elements to bytes.
    return 2 * (a_lds_elements + b_lds_elements) * sizeof(Float);
}
__device__ void Run(const Float* __restrict__ p_a_global, __device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global, const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const Float* __restrict__ p_c_global,
void* __restrict__ p_shared) const
{ {
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
...@@ -184,8 +212,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -184,8 +212,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr index_t b_block_space = constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
__shared__ Float p_a_block_double[2 * a_block_space]; Float* p_a_block_double = reinterpret_cast<Float*>(p_shared);
__shared__ Float p_b_block_double[2 * b_block_space]; Float* p_b_block_double = p_a_block_double + 2 * a_block_space;
// register allocation for output // register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()]; AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
...@@ -329,6 +357,17 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -329,6 +357,17 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
.Run(p_c_thread, p_c_global); .Run(p_c_thread, p_c_global);
} }
} }
// Convenience overload: declares a __shared__ buffer of GetSharedMemorySize()
// bytes locally and forwards to the four-argument Run, for callers that do not
// supply their own shared-memory pointer.
__device__ void Run(const Float* __restrict__ p_a_global,
                    const Float* __restrict__ p_b_global,
                    Float* __restrict__ p_c_global) const
{
    // GetSharedMemorySize() is in bytes; the array bound is in Float elements.
    constexpr index_t lds_bytes = GetSharedMemorySize();

    __shared__ Float lds_buffer[lds_bytes / sizeof(Float)];

    Run(p_a_global, p_b_global, p_c_global, lds_buffer);
}
}; };
} // namespace ck } // namespace ck
......
...@@ -187,7 +187,7 @@ int main(int argc, char* argv[]) ...@@ -187,7 +187,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>; using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>; using RightPads = Sequence<3, 0>;
#elif 1 #elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -249,7 +249,7 @@ int main(int argc, char* argv[]) ...@@ -249,7 +249,7 @@ int main(int argc, char* argv[])
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0 #elif 0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 1 #elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#else #else
device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
......
...@@ -296,7 +296,7 @@ int main(int argc, char* argv[]) ...@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81% // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128; constexpr index_t N = 128;
...@@ -327,7 +327,7 @@ int main(int argc, char* argv[]) ...@@ -327,7 +327,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>; using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>;
#elif 0 #elif 1
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment