Commit fccf044a authored by Jing Zhang's avatar Jing Zhang
Browse files

static_if out

parent 067e23e9
...@@ -48,15 +48,8 @@ struct GetInGlobalMergeDesc<true, InType, N1, N2, Ho, Wo, Strides, Dilations> ...@@ -48,15 +48,8 @@ struct GetInGlobalMergeDesc<true, InType, N1, N2, Ho, Wo, Strides, Dilations>
constexpr auto in_n_c_h_w_global_desc = InType{}; constexpr auto in_n_c_h_w_global_desc = InType{};
// constexpr auto out_n_k_h_w_global_desc = OutType{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
// constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
// constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
constexpr index_t N0 = N / (N1 * N2); constexpr index_t N0 = N / (N1 * N2);
constexpr auto in_n0_n1_n2_h_w_global_desc = constexpr auto in_n0_n1_n2_h_w_global_desc =
...@@ -102,13 +95,13 @@ struct GetInGlobalMergeDesc<true, InType, N1, N2, Ho, Wo, Strides, Dilations> ...@@ -102,13 +95,13 @@ struct GetInGlobalMergeDesc<true, InType, N1, N2, Ho, Wo, Strides, Dilations>
} }
}; };
template <class InType, template <class InType,
index_t N1, index_t N1,
index_t N2, index_t N2,
index_t Ho, index_t Ho,
index_t Wo, index_t Wo,
class Strides, class Strides,
class Dilations> class Dilations>
struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations> struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations>
{ {
__host__ __device__ constexpr auto GetDesc() __host__ __device__ constexpr auto GetDesc()
...@@ -120,11 +113,6 @@ struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations> ...@@ -120,11 +113,6 @@ struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations>
constexpr auto in_n_c_h_w_global_desc = InType{}; constexpr auto in_n_c_h_w_global_desc = InType{};
// constexpr auto out_n_k_h_w_global_desc = OutType{};
// constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
// constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr auto in_n0_n1_n2_h_w_global_desc = constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc in_n_c_h_w_global_desc
.Slice(I2, Number<mod_conv::integer_divide_ceil(Ho, Strides::Get(I0))>{}) .Slice(I2, Number<mod_conv::integer_divide_ceil(Ho, Strides::Get(I0))>{})
...@@ -152,7 +140,61 @@ struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations> ...@@ -152,7 +140,61 @@ struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations>
} }
}; };
#define FORW 0 template <bool isForw, class OutType, class Strides>
struct GetOutGlobalMergeDesc;
template <class OutType, class Strides>
struct GetOutGlobalMergeDesc<true, OutType, Strides>
{
__host__ __device__ constexpr auto GetDesc() { return OutType{}; }
};
template <class OutType, class Strides>
struct GetOutGlobalMergeDesc<false, OutType, Strides>
{
__host__ __device__ constexpr auto GetDesc()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc = OutType{};
constexpr auto out_lengths_new = Sequence<
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I1),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I2),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I3),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I4),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I5),
mod_conv::integer_divide_ceil(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I6),
Strides{}.Get(I0)),
mod_conv::integer_divide_ceil(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I7),
Strides{}.Get(I1))>{};
constexpr auto out_strides_new =
Sequence<out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I1),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I2),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I3),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I4),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I5),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I6) * Strides{}.Get(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I7) * Strides{}.Get(I1)>{};
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc =
make_ConstantTensorDescriptor(out_lengths_new, out_strides_new);
return out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc;
}
};
// define B = merge(N0, Ho, Wo) // define B = merge(N0, Ho, Wo)
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
...@@ -530,36 +572,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -530,36 +572,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{}) out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
.Fold(I0, Number<N1>{}, Number<N2>{}); .Fold(I0, Number<N1>{}, Number<N2>{});
#if FORW
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc = constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc =
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc; GetOutGlobalMergeDesc<fwd,
#else decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
constexpr auto out_lengths_new = Sequence< Strides>{}
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I0), .GetDesc();
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I1),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I2),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I3),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I4),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I5),
mod_conv::integer_divide_ceil(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I6), Strides{}.Get(I0)),
mod_conv::integer_divide_ceil(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I7), Strides{}.Get(I1))>{};
constexpr auto out_strides_new = Sequence<
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I1),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I2),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I3),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I4),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I5),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I6) * Strides{}.Get(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I7) * Strides{}.Get(I1)>{};
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc =
make_ConstantTensorDescriptor(out_lengths_new, out_strides_new);
#endif
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment