Commit bb37eb69 authored by Jing Zhang

merge forward and backward convolution paths

parent 81ec25b4
......@@ -56,8 +56,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t N1 = 2;
constexpr index_t N2 = 4;
//constexpr index_t B = N * mod_conv::integer_divide_ceil(Ho, Strides::Get(I0)) *
//mod_conv::integer_divide_ceil(Wo, Strides::Get(I1)) / (N1 * N2);
// constexpr index_t B = N * mod_conv::integer_divide_ceil(Ho, Strides::Get(I0)) *
// mod_conv::integer_divide_ceil(Wo, Strides::Get(I1)) / (N1 * N2);
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
#if 1
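
Note: a minimal standalone sketch of the two B formulas in this hunk. B is the merged N*Ho*Wo GEMM dimension divided across the N1*N2 per-thread sub-batch; the commented-out variant additionally divides Ho and Wo by the convolution strides. The concrete sizes and the local integer_divide_ceil helper (a stand-in for mod_conv::integer_divide_ceil) are illustrative assumptions, not values from this commit.

    // Sketch only: illustrative sizes, local stand-in for mod_conv::integer_divide_ceil.
    #include <cstdio>

    constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

    int main()
    {
        constexpr int N = 32, Ho = 28, Wo = 28;  // illustrative tensor sizes
        constexpr int N1 = 2, N2 = 4;            // per-thread sub-batch split used by the kernel
        constexpr int HStride = 2, WStride = 2;  // illustrative convolution strides

        // Active formula: Ho/Wo are already the output sizes.
        constexpr int B_plain = (N * Ho * Wo) / (N1 * N2); // 3136 for these sizes

        // Commented-out variant: Ho/Wo first shrunk by the convolution strides.
        constexpr int B_strided =
            N * integer_divide_ceil(Ho, HStride) * integer_divide_ceil(Wo, WStride) / (N1 * N2); // 784

        std::printf("B_plain = %d, B_strided = %d\n", B_plain, B_strided);
        return 0;
    }
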
......@@ -113,6 +113,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
BlockSize,
Strides,
Dilations,
Direction,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
......
......@@ -493,13 +493,13 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
constexpr index_t HStride = 2;
constexpr index_t WStride = 2;
constexpr index_t HStride = 1;
constexpr index_t WStride = 1;
constexpr index_t HDilation = 1;
constexpr index_t WDilation = 1;
constexpr index_t Direction = 2; // 1: Forward; 2:Backward
constexpr index_t Direction = 1; // 1: Forward; 2:Backward
#if 0
constexpr index_t N = 32;
constexpr index_t C = 128;
......
......@@ -7,11 +7,14 @@
#include "blockwise_gemm.hip.hpp"
#include "threadwise_generic_tensor_slice_op.hip.hpp"
#define FORW 1
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Strides,
class Dilations,
index_t Direction,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
......@@ -50,8 +53,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_conv_out_global) const
{
auto p_in_global = p_conv_out_global;
auto p_out_global = p_conv_in_global;
auto p_in_global = Direction == 1 ? p_conv_in_global : p_conv_out_global;
auto p_out_global = Direction == 1 ? p_conv_out_global : p_conv_in_global;
// this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
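
Note: the pointer swap in this hunk is what the new Direction parameter buys: one kernel body serves both forward and backward data convolution by choosing, at compile time, which of the two conv tensors plays the GEMM input role. A minimal sketch of the pattern; the function name and the simplified all-non-const signature are hypothetical, while the 1 = forward / 2 = backward convention is taken from the diff.

    // Sketch: compile-time direction dispatch with the buffer roles swapped.
    #include <cstdio>

    template <int Direction>
    void run_conv_sketch(float* p_conv_in_global, float* p_conv_out_global)
    {
        // Forward reads the conv input and writes the conv output; backward
        // data reuses the same GEMM body with the two buffers swapped.
        float* p_in_global  = (Direction == 1) ? p_conv_in_global  : p_conv_out_global;
        float* p_out_global = (Direction == 1) ? p_conv_out_global : p_conv_in_global;

        std::printf("reading from %p, writing to %p\n",
                    static_cast<void*>(p_in_global), static_cast<void*>(p_out_global));
    }

    int main()
    {
        float in[4] = {}, out[4] = {};
        run_conv_sketch<1>(in, out); // forward
        run_conv_sketch<2>(in, out); // backward data
        return 0;
    }
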
......@@ -71,10 +74,16 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = OutGlobalDesc{};
#if FORW
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
#else
constexpr auto in_n_c_h_w_global_desc = OutGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = InGlobalDesc{};
#endif
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3) out = ek
constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = InGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
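
Note: the #if FORW block in this hunk selects which global descriptor type is treated as the kernel's input view. A minimal sketch of making the same selection with the Direction template parameter instead of a preprocessor switch; DescriptorRoles and the stub descriptor types are hypothetical stand-ins, not the repo's ConstantTensorDescriptor API.

    // Sketch: compile-time choice of the "input"/"output" descriptor types
    // based on Direction, replacing the #if FORW preprocessor switch.
    #include <type_traits>

    struct InGlobalDescStub  { static constexpr int id = 0; }; // stand-in types
    struct OutGlobalDescStub { static constexpr int id = 1; };

    template <int Direction, class InGlobalDesc, class OutGlobalDesc>
    struct DescriptorRoles
    {
        // Forward (Direction == 1): the conv input is also the GEMM input.
        // Backward data: the conv output gradient takes the GEMM input role.
        using InView  = std::conditional_t<Direction == 1, InGlobalDesc, OutGlobalDesc>;
        using OutView = std::conditional_t<Direction == 1, OutGlobalDesc, InGlobalDesc>;
    };

    static_assert(DescriptorRoles<1, InGlobalDescStub, OutGlobalDescStub>::InView::id == 0, "forward");
    static_assert(DescriptorRoles<2, InGlobalDescStub, OutGlobalDescStub>::InView::id == 1, "backward");

    int main() { return 0; }
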
......@@ -112,34 +121,62 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
#if FORW
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
#else
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc
.Slice(I2, Number<mod_conv::integer_divide_ceil(Ho, Strides::Get(I0))>{})
.Slice(I3, Number<mod_conv::integer_divide_ceil(Wo, Strides::Get(I1))>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
#endif
// constexpr auto in_n0_n1_n2_h_w_global_desc =
// in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
//.Extract(Sequence<0, 1, 2, 4, 5>{});
#if FORW
constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
constexpr auto in_strides_new =
Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
in_n0_n1_n2_h_w_global_desc.GetStride(I1),
in_n0_n1_n2_h_w_global_desc.GetStride(I2),
in_n0_n1_n2_h_w_global_desc.GetStride(I3) * Strides{}.Get(I0),
in_n0_n1_n2_h_w_global_desc.GetStride(I4) * Strides{}.Get(I1)>{};
// constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
constexpr auto in_n0_n1_n2_h_w_new_global_desc =
make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
#else
constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
#endif
// batch descriptor for device memory
// to-do: add dilation: keep lengths, modify strides
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
.Slice(I3, Number<X>{})
.Extract(Sequence<1, 2, 3>{});
#if FORW
constexpr auto in_win_lengths_new = Sequence<in_c_y_x_global_desc.GetLength(I0),
in_c_y_x_global_desc.GetLength(I1),
in_c_y_x_global_desc.GetLength(I2)>{};
constexpr auto in_win_strides_new =
Sequence<in_c_y_x_global_desc.GetStride(I0),
in_c_y_x_global_desc.GetStride(I1) * Dilations{}.Get(I0),
in_c_y_x_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{};
constexpr auto in_c_y_x_new_global_desc =
make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
#else
constexpr auto in_c_y_x_new_global_desc = in_c_y_x_global_desc;
#endif
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc),
in_c_y_x_new_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc),
Sequence<0, 1, 2>{},
Sequence<4>{},
Sequence<3, 6, 7>{},
......@@ -174,8 +211,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
InBlockCopyDstDataPerWrite_N2>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc = wei_k_c_1_1_global_desc.Unfold(I1, I3);
// tensor descriptor in LDS, dst of blockwise copy
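
Note: the in_strides_new / in_win_strides_new descriptors built in this hunk keep the lengths and multiply the H/W strides by the convolution strides (and the filter-window Y/X strides by the dilations), so the ordinary dense copy loops implicitly sample every Strides-th output position and every Dilations-th filter tap. A tiny sketch of that idea on a plain 1D view; the Dim1d struct and the concrete numbers are illustrative, not the repo's ConstantTensorDescriptor, and padding is assumed to be zero.

    // Sketch: a "descriptor" is lengths + strides; scaling a stride by the
    // convolution stride (or dilation) makes an ordinary dense loop visit
    // strided / dilated positions of the underlying memory.
    #include <cstdio>

    struct Dim1d { int length; int stride; };

    int offset(Dim1d d, int i) { return i * d.stride; }

    int main()
    {
        constexpr int conv_stride = 2; // illustrative
        constexpr int dilation    = 2; // illustrative

        Dim1d w_dense  {10, 1};                              // plain W dimension of the input
        Dim1d w_strided{3, w_dense.stride * conv_stride};    // Wo positions, stride scaled
        Dim1d x_dilated{3, w_dense.stride * dilation};       // filter X taps, stride scaled

        for(int wo = 0; wo < w_strided.length; ++wo)
            for(int x = 0; x < x_dilated.length; ++x)
                std::printf("out col %d, tap %d -> input offset %d\n",
                            wo, x, offset(w_strided, wo) + offset(x_dilated, x));
        return 0;
    }
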
......@@ -396,6 +433,10 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
.Fold(I0, Number<N1>{}, Number<N2>{});
#if FORW
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc =
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc;
#else
constexpr auto out_lengths_new = Sequence<
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I1),
......@@ -420,6 +461,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc =
make_ConstantTensorDescriptor(out_lengths_new, out_strides_new);
#endif
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
......
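
Note: the Fold calls above (e.g. Fold(I1, Number<K1>{}, Number<K2>{})) split one dimension into sub-dimensions whose product reproduces the original length, with the strides scaled accordingly. A tiny worked sketch of that length/stride arithmetic; the fold helper and the concrete numbers are illustrative, not the repo's ConstantTensorDescriptor API.

    // Sketch: folding a dimension of length L and stride s into three
    // sub-dimensions (L/(f1*f2), f1, f2). Row-major within the fold, so the
    // strides become (s*f1*f2, s*f2, s) and every element keeps its offset.
    #include <cstdio>

    struct Folded { int len[3]; int stride[3]; };

    Folded fold(int L, int s, int f1, int f2)
    {
        return Folded{{L / (f1 * f2), f1, f2}, {s * f1 * f2, s * f2, s}};
    }

    int main()
    {
        // e.g. the K dimension folded into (K0, K1, K2) as in the output descriptor
        constexpr int K = 128, K1 = 4, K2 = 4; // illustrative sizes
        Folded k = fold(K, /*stride*/ 1, K1, K2);

        // a linear index k_idx maps to (k0, k1, k2) with the same memory offset
        int k_idx = 37;
        int k0 = k_idx / (K1 * K2), rem = k_idx % (K1 * K2);
        int k1 = rem / K2, k2 = rem % K2;
        int offs = k0 * k.stride[0] + k1 * k.stride[1] + k2 * k.stride[2];
        std::printf("k=%d -> (k0,k1,k2)=(%d,%d,%d), offset=%d\n", k_idx, k0, k1, k2, offs);
        return 0;
    }
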