Commit 067e23e9 authored by Jing Zhang's avatar Jing Zhang
Browse files

static_if input

parent 1f0bc665
...@@ -19,6 +19,139 @@ struct OutGlobalDescType ...@@ -19,6 +19,139 @@ struct OutGlobalDescType
typename std::conditional<isForw, OutGlobalDesc, InGlobalDesc>::type Type; typename std::conditional<isForw, OutGlobalDesc, InGlobalDesc>::type Type;
}; };
// Primary template for building the merged global-memory descriptor of the
// input tensor. It is declared only; the usable definitions are the partial
// specializations on isForw (true / false), each of which exposes GetDesc().
//
// Template parameters, as used by the specializations:
//   isForw    - selects which specialization is instantiated.
//   InType    - descriptor type; default-constructed inside GetDesc() and
//               treated as the input descriptor (named n_c_h_w there).
//   N1, N2    - factors passed to Fold() on dimension 0 of the input descriptor.
//   Ho, Wo    - isForw == true : used directly as the last two lengths of the
//                                rebuilt [N0, N1, N2, Ho, Wo] descriptor.
//               isForw == false: divided by Strides (via
//                                mod_conv::integer_divide_ceil) to obtain the
//                                slice lengths of dimensions 2 and 3.
//   Strides   - indexed with Get(I0) / Get(I1).
//   Dilations - indexed with Get(I0) / Get(I1); only read when isForw == true.
template <bool isForw,
class InType,
index_t N1,
index_t N2,
index_t Ho,
index_t Wo,
class Strides,
class Dilations>
struct GetInGlobalMergeDesc;
// isForw == true specialization.
//
// GetDesc() rebuilds two descriptors from InType with rescaled strides and
// merges them:
//   * a [C, 1, 1] "window" descriptor whose last two strides are multiplied
//     by Dilations{}.Get(I0) / Dilations{}.Get(I1);
//   * a [N0, N1, N2, Ho, Wo] descriptor whose last two strides are multiplied
//     by Strides{}.Get(I0) / Strides{}.Get(I1).
// The window descriptor embeds the batch/spatial one and the resulting
// dimensions are merged as {0,1,2}, {4}, {3,6,7}, {5}.
template <class InType,
          index_t N1,
          index_t N2,
          index_t Ho,
          index_t Wo,
          class Strides,
          class Dilations>
struct GetInGlobalMergeDesc<true, InType, N1, N2, Ho, Wo, Strides, Dilations>
{
    // Returns the merged descriptor described above ([E, N1, B, N2], the
    // source of the blockwise copy).
    __host__ __device__ constexpr auto GetDesc()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};

        constexpr auto src_desc = InType{};

        // N must split evenly into N0 * N1 * N2.
        constexpr index_t N = src_desc.GetLength(I0);
        static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
        constexpr index_t N0 = N / (N1 * N2);

        // ---- window part: [C, 1, 1], strides 1 and 2 scaled by the dilations ----
        constexpr auto window_desc = src_desc.Slice(I2, Number<1>{})
                                         .Slice(I3, Number<1>{})
                                         .Extract(Sequence<1, 2, 3>{});

        constexpr auto window_stride_1 = window_desc.GetStride(I1) * Dilations{}.Get(I0);
        constexpr auto window_stride_2 = window_desc.GetStride(I2) * Dilations{}.Get(I1);

        constexpr auto window_lengths = Sequence<window_desc.GetLength(I0),
                                                 window_desc.GetLength(I1),
                                                 window_desc.GetLength(I2)>{};
        constexpr auto window_strides =
            Sequence<window_desc.GetStride(I0), window_stride_1, window_stride_2>{};

        constexpr auto dilated_window_desc =
            make_ConstantTensorDescriptor(window_lengths, window_strides);

        // ---- batch/spatial part: [N0, N1, N2, Ho, Wo], strides 3 and 4 scaled
        //      by the convolution strides ----
        constexpr auto folded_desc = src_desc.Fold(I0, Number<N1>{}, Number<N2>{})
                                         .Extract(Sequence<0, 1, 2, 4, 5>{});

        constexpr auto spatial_stride_h = folded_desc.GetStride(I3) * Strides{}.Get(I0);
        constexpr auto spatial_stride_w = folded_desc.GetStride(I4) * Strides{}.Get(I1);

        constexpr auto spatial_lengths = Sequence<N0, N1, N2, Ho, Wo>{};
        constexpr auto spatial_strides = Sequence<folded_desc.GetStride(I0),
                                                  folded_desc.GetStride(I1),
                                                  folded_desc.GetStride(I2),
                                                  spatial_stride_h,
                                                  spatial_stride_w>{};

        constexpr auto strided_spatial_desc =
            make_ConstantTensorDescriptor(spatial_lengths, spatial_strides);

        // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
        constexpr auto merged_desc =
            make_ConstantMergedTensorDescriptor(dilated_window_desc.Embed(strided_spatial_desc),
                                                Sequence<0, 1, 2>{},
                                                Sequence<4>{},
                                                Sequence<3, 6, 7>{},
                                                Sequence<5>{});

        return merged_desc;
    }
};
// isForw == false specialization.
//
// Unlike the isForw == true case, no strides are rescaled here. Instead,
// dimensions 2 and 3 of the input descriptor are sliced to
// integer_divide_ceil(Ho, Strides::Get(I0)) and
// integer_divide_ceil(Wo, Strides::Get(I1)) before dimension 0 is folded by
// N1 and N2. Dilations is not read by this specialization.
template <class InType,
          index_t N1,
          index_t N2,
          index_t Ho,
          index_t Wo,
          class Strides,
          class Dilations>
struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations>
{
    // Returns the merged descriptor ([E, N1, B, N2], the source of the
    // blockwise copy) built from the sliced/folded input descriptor.
    __host__ __device__ constexpr auto GetDesc()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto src_desc = InType{};

        // ---- window part: [C, 1, 1], used with its original strides ----
        constexpr auto window_desc = src_desc.Slice(I2, Number<1>{})
                                         .Slice(I3, Number<1>{})
                                         .Extract(Sequence<1, 2, 3>{});

        // ---- batch/spatial part: slice H and W, then fold dimension 0 ----
        constexpr auto sliced_h = mod_conv::integer_divide_ceil(Ho, Strides::Get(I0));
        constexpr auto sliced_w = mod_conv::integer_divide_ceil(Wo, Strides::Get(I1));

        constexpr auto spatial_desc = src_desc.Slice(I2, Number<sliced_h>{})
                                          .Slice(I3, Number<sliced_w>{})
                                          .Fold(I0, Number<N1>{}, Number<N2>{})
                                          .Extract(Sequence<0, 1, 2, 4, 5>{});

        // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
        constexpr auto merged_desc =
            make_ConstantMergedTensorDescriptor(window_desc.Embed(spatial_desc),
                                                Sequence<0, 1, 2>{},
                                                Sequence<4>{},
                                                Sequence<3, 6, 7>{},
                                                Sequence<5>{});

        return merged_desc;
    }
};
#define FORW 0 #define FORW 0
// define B = merge(N0, Ho, Wo) // define B = merge(N0, Ho, Wo)
template <index_t GridSize, template <index_t GridSize,
...@@ -134,56 +267,17 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -134,56 +267,17 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto in_c_1_1_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{}) constexpr auto in_c_1_1_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{}) .Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{}); .Extract(Sequence<1, 2, 3>{});
// input tensor constexpr bool fwd = Direction == 1;
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo] constexpr auto in_e_n1_b_n2_global_merged_desc =
#if FORW GetInGlobalMergeDesc<fwd,
constexpr auto in_n0_n1_n2_h_w_global_desc = decltype(in_n_c_h_w_global_desc),
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{}) N1,
.Extract(Sequence<0, 1, 2, 4, 5>{}); N2,
constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{}; Ho,
Wo,
constexpr auto in_strides_new = Strides,
Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0), Dilations>{}
in_n0_n1_n2_h_w_global_desc.GetStride(I1), .GetDesc();
in_n0_n1_n2_h_w_global_desc.GetStride(I2),
in_n0_n1_n2_h_w_global_desc.GetStride(I3) * Strides{}.Get(I0),
in_n0_n1_n2_h_w_global_desc.GetStride(I4) * Strides{}.Get(I1)>{};
constexpr auto in_n0_n1_n2_h_w_new_global_desc =
make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
constexpr auto in_win_lengths_new = Sequence<in_c_1_1_global_desc.GetLength(I0),
in_c_1_1_global_desc.GetLength(I1),
in_c_1_1_global_desc.GetLength(I2)>{};
constexpr auto in_win_strides_new =
Sequence<in_c_1_1_global_desc.GetStride(I0),
in_c_1_1_global_desc.GetStride(I1) * Dilations{}.Get(I0),
in_c_1_1_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{};
constexpr auto in_c_1_1_new_global_desc =
make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
#else
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc
.Slice(I2, Number<mod_conv::integer_divide_ceil(Ho, Strides::Get(I0))>{})
.Slice(I3, Number<mod_conv::integer_divide_ceil(Wo, Strides::Get(I1))>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
constexpr auto in_c_1_1_new_global_desc = in_c_1_1_global_desc;
#endif
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_1_1_new_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc),
Sequence<0, 1, 2>{},
Sequence<4>{},
Sequence<3, 6, 7>{},
Sequence<5>{});
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment