Commit 1f0bc665 authored by Jing Zhang's avatar Jing Zhang
Browse files

type

parent 8e51f990
...@@ -499,7 +499,7 @@ int main(int argc, char* argv[]) ...@@ -499,7 +499,7 @@ int main(int argc, char* argv[])
constexpr index_t HDilation = 1; constexpr index_t HDilation = 1;
constexpr index_t WDilation = 1; constexpr index_t WDilation = 1;
constexpr index_t Direction = 2; // 1: Forward; 0:Backward constexpr index_t Direction = 0; // 1: Forward; 0:Backward
#if 0 #if 0
constexpr index_t N = 32; constexpr index_t N = 32;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -551,8 +551,8 @@ int main(int argc, char* argv[]) ...@@ -551,8 +551,8 @@ int main(int argc, char* argv[])
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
constexpr index_t HI = 13; constexpr index_t HI = 7;
constexpr index_t WI = 13; constexpr index_t WI = 7;
constexpr index_t K = 128; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
......
...@@ -7,6 +7,18 @@ ...@@ -7,6 +7,18 @@
#include "blockwise_gemm.hip.hpp" #include "blockwise_gemm.hip.hpp"
#include "threadwise_generic_tensor_slice_op.hip.hpp" #include "threadwise_generic_tensor_slice_op.hip.hpp"
// Selects the tensor descriptor to treat as the "input" of the convolution:
// InGlobalDesc when running forward (isForw == true), OutGlobalDesc when
// running backward. Note: Type is a (default-constructed) data member, not a
// nested alias, so callers obtain the chosen descriptor value by writing
// InGlobalDescType<...>{}.Type.
template <bool isForw, class InGlobalDesc, class OutGlobalDesc>
struct InGlobalDescType
{
    std::conditional_t<isForw, InGlobalDesc, OutGlobalDesc> Type;
};
// Mirror of InGlobalDescType: selects the descriptor to treat as the
// "output" of the convolution — OutGlobalDesc in the forward direction
// (isForw == true), InGlobalDesc in the backward direction. As with
// InGlobalDescType, Type is a default-constructed member object, accessed as
// OutGlobalDescType<...>{}.Type.
template <bool isForw, class InGlobalDesc, class OutGlobalDesc>
struct OutGlobalDescType
{
    std::conditional_t<isForw, OutGlobalDesc, InGlobalDesc> Type;
};
#define FORW 0 #define FORW 0
// define B = merge(N0, Ho, Wo) // define B = merge(N0, Ho, Wo)
template <index_t GridSize, template <index_t GridSize,
...@@ -73,13 +85,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -73,13 +85,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
#if FORW constexpr auto in_n_c_h_w_global_desc =
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; InGlobalDescType<Direction == 1, InGlobalDesc, OutGlobalDesc>{}.Type;
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; constexpr auto out_n_k_h_w_global_desc =
#else OutGlobalDescType<Direction == 1, InGlobalDesc, OutGlobalDesc>{}.Type;
constexpr auto in_n_c_h_w_global_desc = OutGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = InGlobalDesc{};
#endif
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek // to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek
constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{}; constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{};
...@@ -92,8 +102,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -92,8 +102,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_1_1_global_desc.GetLength(I2); // constexpr index_t Y = wei_k_c_1_1_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_1_1_global_desc.GetLength(I3); // constexpr index_t X = wei_k_c_1_1_global_desc.GetLength(I3);
static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread"); static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
...@@ -101,7 +111,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -101,7 +111,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr index_t B = N0 * Ho * Wo; constexpr index_t B = N0 * Ho * Wo;
constexpr index_t E = C * Y * X; constexpr index_t E = C;
// divide block work by [K, B] // divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0, static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
...@@ -119,22 +129,17 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -119,22 +129,17 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// batch descritpor for device memory
// to-do: add dilation: keep lengths, modify strides
constexpr auto in_c_1_1_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{});
// input tensor // input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo] // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
#if FORW #if FORW
constexpr auto in_n0_n1_n2_h_w_global_desc = constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{}) in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{}); .Extract(Sequence<0, 1, 2, 4, 5>{});
#else
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc
.Slice(I2, Number<mod_conv::integer_divide_ceil(Ho, Strides::Get(I0))>{})
.Slice(I3, Number<mod_conv::integer_divide_ceil(Wo, Strides::Get(I1))>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
#endif
#if FORW
constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{}; constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
constexpr auto in_strides_new = constexpr auto in_strides_new =
...@@ -147,34 +152,34 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -147,34 +152,34 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto in_n0_n1_n2_h_w_new_global_desc = constexpr auto in_n0_n1_n2_h_w_new_global_desc =
make_ConstantTensorDescriptor(in_lengths_new, in_strides_new); make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
#else constexpr auto in_win_lengths_new = Sequence<in_c_1_1_global_desc.GetLength(I0),
in_c_1_1_global_desc.GetLength(I1),
constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc; in_c_1_1_global_desc.GetLength(I2)>{};
#endif
// batch descritpor for device memory
// to-do: add dilation: keep lengths, modify strides
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
.Slice(I3, Number<X>{})
.Extract(Sequence<1, 2, 3>{});
#if FORW
constexpr auto in_win_lengths_new = Sequence<in_c_y_x_global_desc.GetLength(I0),
in_c_y_x_global_desc.GetLength(I1),
in_c_y_x_global_desc.GetLength(I2)>{};
constexpr auto in_win_strides_new = constexpr auto in_win_strides_new =
Sequence<in_c_y_x_global_desc.GetStride(I0), Sequence<in_c_1_1_global_desc.GetStride(I0),
in_c_y_x_global_desc.GetStride(I1) * Dilations{}.Get(I0), in_c_1_1_global_desc.GetStride(I1) * Dilations{}.Get(I0),
in_c_y_x_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{}; in_c_1_1_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{};
constexpr auto in_c_y_x_new_global_desc = constexpr auto in_c_1_1_new_global_desc =
make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new); make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
#else #else
constexpr auto in_c_y_x_new_global_desc = in_c_y_x_global_desc; constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc
.Slice(I2, Number<mod_conv::integer_divide_ceil(Ho, Strides::Get(I0))>{})
.Slice(I3, Number<mod_conv::integer_divide_ceil(Wo, Strides::Get(I1))>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
constexpr auto in_c_1_1_new_global_desc = in_c_1_1_global_desc;
#endif #endif
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_new_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc), in_c_1_1_new_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc),
Sequence<0, 1, 2>{}, Sequence<0, 1, 2>{},
Sequence<4>{}, Sequence<4>{},
Sequence<3, 6, 7>{}, Sequence<3, 6, 7>{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment