"vscode:/vscode.git/clone" did not exist on "cee98c903bbeccfffe5636c6fdfb4805edcaa1fc"
Commit d4878d99 authored by Chao Liu

initial padding support for nchw

parent bd7a2300
@@ -51,6 +51,7 @@ template <index_t GridSize,
          index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
{
+ #if 1
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
@@ -69,20 +70,24 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
-   constexpr auto I5 = Number<5>{};

    constexpr auto True = integral_constant<bool, true>{};

-   constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
-   constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
-   constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
+   constexpr auto in_n_c_hi_wi_global_desc =
+       make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
+   constexpr auto wei_k_c_y_x_global_desc =
+       make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
+   constexpr auto out_n_k_ho_wo_global_desc =
+       make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());

-   constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
-   constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
+   constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
+   constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
+   constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
+   constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

-   constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
-   constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
-   constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
+   constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
+   constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
+   constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

    constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
    constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
@@ -126,30 +131,35 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
    const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

    // input tensor
-   // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
-   constexpr auto in_n0_n1_n2_h_w_global_desc =
-       in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
-           .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
-           .Fold(I0, Number<N1>{}, Number<N2>{})
-           .Extract(Sequence<0, 1, 2, 4, 5>{});
-
-   // batch descritpor for device memory
-   constexpr auto in_c_y_x_global_desc =
-       in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
-           .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
-           .Extract(Sequence<1, 2, 3>{});
-
-   // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
-   constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
-       in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
-       Sequence<0, 1, 2>{},
-       Sequence<4>{},
-       Sequence<3, 6, 7>{},
-       Sequence<5>{});
+   // global memory
+   constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
+       in_n_c_hi_wi_global_desc,
+       make_tuple(
+           PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
+       make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
+       make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
+
+   constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
+       in_n_c_hip_wip_global_desc,
+       make_tuple(Unmerge<Sequence<N0, N1, N2>>{},
+                  PassThrough<C>{},
+                  Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
+                  Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
+       make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+       make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
+
+   constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
+       in_n0_n1_n2_c_y_ho_x_wo_global_desc,
+       make_tuple(Merge<Sequence<C, Y, X>>{},
+                  PassThrough<N1>{},
+                  Merge<Sequence<N0, Ho, Wo>>{},
+                  PassThrough<N2>{}),
+       make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
+       make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
    // be careful of LDS alignment
-   constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
+   constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
        Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

    // this check is ad-hoc
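For orientation (an editor's sketch, not part of the commit): the three transform_tensor_descriptor calls above compose Pad, Unmerge/Embed, and Merge so that a coordinate of the merged [E, N1, B, N2] view can be traced back to the underlying NCHW input. A rough host-side restatement in plain C++, with every name below invented for illustration:

#include <cstdint>

// All names here are hypothetical; shapes and pads mirror the kernel's parameters.
struct ConvShape
{
    int64_t C, Hi, Wi, Y, X, Ho, Wo;
    int64_t stride_h, stride_w, dilation_h, dilation_w;
    int64_t left_pad_h, left_pad_w;
    int64_t N1, N2; // GemmNRepeat, GemmNPerThreadSubC
};

// Trace an (e, n1, b, n2) index of the merged [E, N1, B, N2] view back to the
// underlying NCHW input, the way Pad -> Unmerge/Embed -> Merge compose above.
// Returns false when the point falls into the padding area.
bool map_e_n1_b_n2_to_nchw(const ConvShape& p,
                           int64_t e, int64_t n1, int64_t b, int64_t n2,
                           int64_t& n, int64_t& c, int64_t& hi, int64_t& wi)
{
    // E = C * Y * X, split in the order used by Merge<Sequence<C, Y, X>>
    const int64_t x = e % p.X;
    const int64_t y = (e / p.X) % p.Y;
    c               = e / (p.X * p.Y);

    // B = N0 * Ho * Wo, split in the order used by Merge<Sequence<N0, Ho, Wo>>
    const int64_t wo = b % p.Wo;
    const int64_t ho = (b / p.Wo) % p.Ho;
    const int64_t n0 = b / (p.Wo * p.Ho);

    n = (n0 * p.N1 + n1) * p.N2 + n2; // inverse of Unmerge<Sequence<N0, N1, N2>>

    // Embed coefficients (dilation, stride, 0) give the padded coordinate,
    // and Pad's lower index subtracts the left pad again.
    hi = y * p.dilation_h + ho * p.stride_h - p.left_pad_h;
    wi = x * p.dilation_w + wo * p.stride_w - p.left_pad_w;

    return hi >= 0 && hi < p.Hi && wi >= 0 && wi < p.Wi;
}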
@@ -162,8 +172,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
    // slice a merged tensor, reorder and copy to a normal tensor
    // this copy operator already has blockwise offset built-in
    auto blockwise_in_copy =
-       BlockwiseGenericTensorSliceCopy_v2<BlockSize,
-                                          decltype(in_e_n1_b_n2_global_merged_desc),
+       BlockwiseGenericTensorSliceCopy_v4<BlockSize,
+                                          decltype(in_e_n1_b_n2_global_desc),
                                           decltype(in_e_n1_b_n2_block_desc),
                                           decltype(in_e_n1_b_n2_block_desc.GetLengths()),
                                           InBlockCopySubLengths_E_N1_B_N2,
@@ -180,11 +190,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
    // weight tensor
    // tensor descriptor in device memory, src of blockwise copy
    constexpr auto wei_e_k_global_desc =
-       wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
+       transform_tensor_descriptor(wei_k_c_y_x_global_desc,
+                                   make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
+                                   make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
+                                   make_tuple(Sequence<0>{}, Sequence<1>{}));

    // tensor descriptor in LDS, dst of blockwise copy
    // be careful of LDS alignment
-   constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
+   constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
        Sequence<EPerBlock, KPerBlock>{},
        Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
@@ -192,7 +205,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
    // slice a tensor, and copy it into another tensor
    // this copy operator already have blockwise offset built-in
    auto blockwise_wei_copy =
-       BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+       BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                           decltype(wei_e_k_global_desc),
                                           decltype(wei_e_k_block_desc),
                                           decltype(wei_e_k_block_desc.GetLengths()),
@@ -215,8 +228,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
    // register
    constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

-   constexpr auto b_e_n1bn2_block_mtx_desc =
-       make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
+   constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor(
+       in_e_n1_b_n2_block_desc.GetLength(I0),
+       in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
+           in_e_n1_b_n2_block_desc.GetLength(I3),
+       in_e_n1_b_n2_block_desc.GetStride(I0));

    // sanity check
    static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
@@ -288,21 +304,28 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
    constexpr index_t K2 = GemmMPerThreadSubC;
    constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;

+   static_assert(K % (K1 * K2) == 0, "wrong!");
+
    // define tensor descriptor for threadwise copy

    // output memory layout descriptor in register
-   constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
-       make_ConstantTensorDescriptor_packed(
+   constexpr auto out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc =
+       make_native_tensor_descriptor_packed(
            Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

    // output tensor descriptor in register, src of threadwise copy
-   constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
-       out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
+   constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc =
+       reorder_tensor_descriptor_given_upper2lower(out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc,
            Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

    // output memory layout descriptor in device memory, dst of threadwise copy
-   constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
-       out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
-           .Fold(I0, Number<N1>{}, Number<N2>{});
+   constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc = transform_tensor_descriptor(
+       out_n_k_ho_wo_global_desc,
+       make_tuple(Unmerge<Sequence<N / (N1 * N2), N1, N2>>{},
+                  Unmerge<Sequence<K / (K1 * K2), K1, K2>>{},
+                  PassThrough<Ho>{},
+                  PassThrough<Wo>{}),
+       make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+       make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}, Sequence<6>{}, Sequence<7>{}));

    // calculate origin of thread output tensor on global memory
    // blockwise GEMM c matrix starting index
@@ -317,32 +340,159 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
    // output merged global tensor descriptor, for calculating origin of thread tensor
    // in global memory
-   constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
-       out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
-       Sequence<3>{},
-       Sequence<1>{},
-       Sequence<0, 4, 5>{},
-       Sequence<2>{});
+   constexpr auto out_n0_n1_n2_k_ho_wo_global_desc = transform_tensor_descriptor(
+       out_n_k_ho_wo_global_desc,
+       make_tuple(Unmerge<Sequence<N / (N1 * N2), N1, N2>>{},
+                  PassThrough<K>{},
+                  PassThrough<Ho>{},
+                  PassThrough<Wo>{}),
+       make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+       make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
+
+   constexpr auto out_k_n1_b_n2_global_desc = transform_tensor_descriptor(
+       out_n0_n1_n2_k_ho_wo_global_desc,
+       make_tuple(PassThrough<K>{},
+                  PassThrough<N1>{},
+                  Merge<Sequence<N0, Ho, Wo>>{},
+                  PassThrough<N2>{}),
+       make_tuple(Sequence<3>{}, Sequence<1>{}, Sequence<0, 4, 5>{}, Sequence<2>{}),
+       make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    // origin of dst in device memory
    Float* p_out_thread_on_global =
        p_out_global +
-       out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
-           k_thread_data_on_global, 0, b_thread_data_on_global, 0);
+       out_k_n1_b_n2_global_desc.CalculateOffset(
+           {k_thread_data_on_global, 0, b_thread_data_on_global, 0});

-   ThreadwiseGenericTensorSliceCopy_v2r1<
-       decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
-       decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
-       decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
+   ThreadwiseGenericTensorSliceCopy_v4r2<
+       decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc),
+       decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc),
+       decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc.GetLengths()),
        arithmetic_sequence_gen<0, 8, 1>::type,
-       arithmetic_sequence_gen<0, 8, 1>::type,
-       7,
        7,
        1,
        1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
        .Run(p_out_thread, p_out_thread_on_global);
    }
    }
#else
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// this is a mess
// TODO: find more elegant way of specifying (or calculating) performance parameters
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc =
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr auto wei_k_c_y_x_global_desc =
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
constexpr auto out_n_k_h_w_global_desc =
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among thread");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t B = N0 * Ho * Wo;
constexpr index_t E = C * Y * X;
// sanity-check for vectorized memory load
static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
"wrong! global vector load of input tensor is wrong");
static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
constexpr auto in_n_c_hi_wi_global_desc =
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(Unmerge<Sequence<N0, N1, N2>>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
in_n0_n1_n2_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{},
PassThrough<N1>{},
Merge<Sequence<N0, Ho, Wo>>{},
PassThrough<N2>{}),
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_tensor_descriptor("in_n_c_hi_wi_global_desc: ", in_n_c_hi_wi_global_desc);
print_tensor_descriptor("in_n_c_hip_wip_global_desc: ", in_n_c_hip_wip_global_desc);
print_tensor_descriptor("in_n0_n1_n2_c_y_ho_x_wo_global_desc: ",
in_n0_n1_n2_c_y_ho_x_wo_global_desc);
print_tensor_descriptor("in_e_n1_b_n2_global_desc: ", in_e_n1_b_n2_global_desc);
auto coord3 = make_tensor_coordinate_v2(in_e_n1_b_n2_global_desc, {1, 1, 1, 1});
auto idx3 = coord3.GetIndex();
auto idx2 = coord3.GetLowerCoordinate().GetIndex();
auto idx1 = coord3.GetLowerCoordinate().GetLowerCoordinate().GetIndex();
auto idx0 =
coord3.GetLowerCoordinate().GetLowerCoordinate().GetLowerCoordinate().GetIndex();
print_array("idx3: ", idx3);
print_array("idx2: ", idx2);
print_array("idx1: ", idx1);
print_array("idx0: ", idx0);
}
#endif
p_out_global[0] = in_e_n1_b_n2_global_desc.CalculateOffset({0, 0, 10, 0});
}
#endif
};

} // namespace ck
...
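A similar host-side sketch (again not part of the commit, names hypothetical) of what out_k_n1_b_n2_global_desc.CalculateOffset computes for the thread's output origin, assuming a packed NKHW output layout:

#include <cstdint>

// Offset into a packed NKHW output for an index of the merged [K, N1, B, N2] view,
// mirroring how out_k_n1_b_n2_global_desc is built from out_n_k_ho_wo_global_desc.
// (The real descriptor uses the actual strides; packed layout is assumed here.)
int64_t out_k_n1_b_n2_offset(int64_t K, int64_t Ho, int64_t Wo, int64_t N1, int64_t N2,
                             int64_t k, int64_t n1, int64_t b, int64_t n2)
{
    const int64_t wo = b % Wo;
    const int64_t ho = (b / Wo) % Ho;
    const int64_t n0 = b / (Wo * Ho);

    const int64_t n = (n0 * N1 + n1) * N2 + n2; // undo Unmerge<Sequence<N0, N1, N2>>

    return ((n * K + k) * Ho + ho) * Wo + wo;   // packed NKHW strides
}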
@@ -3,6 +3,7 @@
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
+ #include "tensor_descriptor.hpp"

namespace ck {

@@ -52,7 +53,7 @@ __host__ __device__ constexpr auto
    return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}

- template <class... Ts>
+ template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>)
{
    using TDesc = ConstantTensorDescriptor<Ts...>;

@@ -63,7 +64,18 @@ __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorD
                                    TDesc::GetStrides()[0]>{};
}

- template <class TDesc>
+ template <typename... Ts>
+ __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>)
+ {
+     using TDesc = NativeTensorDescriptor<Ts...>;
+
+     static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
+     static_assert(TDesc::GetStrides()[1] == 1, "wrong");
+
+     return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
+                                     TDesc::GetLengths()[1],
+                                     TDesc::GetStrides()[0]>{};
+ }
+
+ template <typename TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
    printf(
...
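A minimal usage sketch of the new NativeTensorDescriptor overload, assuming the repo's headers for Sequence, make_native_tensor_descriptor, and make_ConstantMatrixDescriptor are included; the concrete lengths and strides are made up for illustration:

using namespace ck;

// a 2-D native descriptor with unit inner stride maps directly to a matrix descriptor
constexpr auto desc_2d = make_native_tensor_descriptor(Sequence<128, 16>{}, // lengths
                                                       Sequence<20, 1>{});  // strides
constexpr auto mtx = make_ConstantMatrixDescriptor(desc_2d);
// mtx is ConstantMatrixDescriptor<128, 16, 20>: 128 rows, 16 columns, row stride 20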
@@ -6,7 +6,7 @@
namespace ck {

template <class Lengths>
- __host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
+ __host__ __device__ constexpr auto calculate_tensor_strides_packed_old(Lengths)
{
    return reverse_inclusive_scan_sequence(
               Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})

@@ -14,12 +14,12 @@ __host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
}

template <class Lengths, index_t Align>
- __host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
+ __host__ __device__ constexpr auto calculate_tensor_strides_aligned_old(Lengths, Number<Align>)
{
    constexpr index_t L_back_align =
        Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);

-   return calculate_tensor_strides_packed(
+   return calculate_tensor_strides_packed_old(
        Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}

@@ -187,7 +187,7 @@ struct ConstantTensorDescriptor
{
    Array<index_t, nDim> multi_id;

-   using PackedStrides = decltype(calculate_tensor_strides_packed(GetLengths()));
+   using PackedStrides = decltype(calculate_tensor_strides_packed_old(GetLengths()));

    // calculate index in each of the dimensions in the order of their dimension
    static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));

@@ -468,7 +468,7 @@ struct ConstantTensorDescriptor
__host__ __device__ static constexpr auto Pack()
{
-   using packed_strides = decltype(calculate_tensor_strides_packed(Lengths{}));
+   using packed_strides = decltype(calculate_tensor_strides_packed_old(Lengths{}));
    return ConstantTensorDescriptor<Lengths, packed_strides>{};
}

@@ -490,7 +490,7 @@ struct ConstantTensorDescriptor
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
{
-   using Strides = decltype(calculate_tensor_strides_packed(Lengths{}));
+   using Strides = decltype(calculate_tensor_strides_packed_old(Lengths{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

@@ -503,7 +503,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
-   using Strides = decltype(calculate_tensor_strides_aligned(Lengths{}, Number<Align>{}));
+   using Strides = decltype(calculate_tensor_strides_aligned_old(Lengths{}, Number<Align>{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
}
...
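The renamed *_old helpers keep the original stride rules; as a quick reference (a standalone sketch, not code from the repo), packed strides are a reverse inclusive scan of the lengths, and the aligned variant first rounds the last length up to the alignment:

#include <array>
#include <cstddef>
#include <cstdio>

// Packed strides: reverse inclusive scan of the lengths (last stride = 1).
// The aligned variant rounds the last length up before computing the scan,
// which is what the *_aligned helper does for LDS descriptors.
template <std::size_t N>
std::array<std::size_t, N> packed_strides(std::array<std::size_t, N> lengths)
{
    std::array<std::size_t, N> strides{};
    strides[N - 1] = 1;
    for(std::size_t i = N - 1; i > 0; --i)
        strides[i - 1] = strides[i] * lengths[i];
    return strides;
}

int main()
{
    std::array<std::size_t, 4> lengths{8, 2, 16, 3};
    const auto packed = packed_strides(lengths);   // {96, 48, 3, 1}

    std::array<std::size_t, 4> rounded = lengths;
    rounded[3] = ((rounded[3] + 3) / 4) * 4;       // align last length to 4
    const auto aligned = packed_strides(rounded);  // {128, 64, 4, 1}

    std::printf("%zu %zu\n", packed[0], aligned[0]);
}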
@@ -24,8 +24,6 @@ struct PassThrough
    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }

-   __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<Length>{}; }
-
    __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)

@@ -51,11 +49,11 @@ }
    }
};

- // LowLengths: Sequence<...>
- template <typename LowLengths, typename LeftPads, typename RightPads>
+ // LowerLengths: Sequence<...>
+ template <typename LowerLengths, typename LeftPads, typename RightPads>
struct Pad
{
-   static constexpr index_t nDim = LowLengths::Size();
+   static constexpr index_t nDim = LowerLengths::Size();

    using LowerIndex = MultiIndex<nDim>;
    using UpperIndex = MultiIndex<nDim>;

@@ -64,11 +62,9 @@ struct Pad
    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }

-   __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
-
    __host__ __device__ static constexpr auto GetUpperLengths()
    {
-       return GetLowerLengths() + LeftPads{} + RightPads{};
+       return LowerLengths{} + LeftPads{} + RightPads{};
    }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)

@@ -98,7 +94,7 @@
            // only check if there is right-padding
            static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
-               flag = flag || idx_up[idim] >= LeftPads::At(idim) + LowLengths::At(idim);
+               flag = flag || idx_up[idim] >= LeftPads::At(idim) + LowerLengths::At(idim);
            });
        });

@@ -106,11 +102,11 @@ }
    }
};
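Sketch of the Pad transform's behavior (an editor's illustration, not repo code; the lower-index mapping idx_low = idx_up - LeftPads is inferred from GetUpperLengths and the right-padding check above):

#include <array>
#include <cstdint>

// Lower index = upper index minus left pads; an upper index lies in the padding
// area when any component falls outside [left_pad, left_pad + lower_length).
struct Pad2D
{
    std::array<int64_t, 2> lower_lengths, left_pads, right_pads;

    std::array<int64_t, 2> lower_index(std::array<int64_t, 2> up) const
    {
        return {up[0] - left_pads[0], up[1] - left_pads[1]};
    }

    bool in_padding_area(std::array<int64_t, 2> up) const
    {
        for(int d = 0; d < 2; ++d)
            if(up[d] < left_pads[d] || up[d] >= left_pads[d] + lower_lengths[d])
                return true;
        return false;
    }
};

// e.g. a 32x32 image with 1-pixel pads is exposed as 34x34:
// Pad2D{{32, 32}, {1, 1}, {1, 1}} maps upper {0, 5} to lower {-1, 4}, which is padding.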
- // LowLengths: Sequence<...>
- template <typename LowLengths>
+ // LowerLengths: Sequence<...>
+ template <typename LowerLengths>
struct Merge
{
-   static constexpr index_t nDimLow = LowLengths::Size();
+   static constexpr index_t nDimLow = LowerLengths::Size();
    static constexpr index_t nDimUp = 1;

    using LowerIndex = MultiIndex<nDimLow>;

@@ -120,12 +116,10 @@ struct Merge
    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

-   __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
-
    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return Sequence<accumulate_on_sequence(
-           GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
+           LowerLengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
    }

    // emulate constexpr lambda

@@ -158,11 +152,11 @@ struct Merge
        constexpr auto pseudo_low_strides =
            reverse_inclusive_scan_sequence(
-               GetLowerLengths().PopFront(), math::multiplies<index_t>{}, Number<1>{})
+               LowerLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
                .PushBack(Number<1>{});

-       // calculate index in each of the dimensions in the order of their dimension
-       #if 1 // would compile to same ISA?
+       #if 1 // would these 2 versions be compiled to same ISA?
+       // calculate index in each of the dimensions in the order of their dimension
        static_for<0, nDimLow - 1, 1>{}(
            lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));

@@ -176,16 +170,75 @@ struct Merge
    }

    // idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date
+   // If idx_up_diff is known at compile-time, many calculations can be optimized
+   // away by compiler
+   // This function assume idx_low_old is not out-of-bound
    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& idx_low_old)
    {
-       LowerIndex idx_low_diff;
+       // do nothing if idx_up_diff == 0
+       if(idx_up_diff[0] == 0)
+       {
+           return make_zero_array<index_t, nDimLow>();
+       }

-       // not implemeneted
+       // CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
+       // If idx_up_diff is known at compile-time, the calculation can
+       // be done at compile-time. However, if idx_up_diff is only known
+       // at run-time, then the calculation will also be computed at
+       // run-time, and can be very expensive.
+       LowerIndex idx_low_new = idx_low_old + CalculateLowerIndex(idx_up_diff);

-       return idx_low_diff;
+       if(idx_up_diff[0] > 0)
{
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDimLow, 1>{}([&](auto ireverse) {
constexpr index_t i = nDimLow - 1 - ireverse;
if(carry)
{
++idx_low_new(i);
}
carry = false;
if(idx_low_new[i] >= LowerLengths::At(i))
{
idx_low_new(i) -= LowerLengths::At(i);
carry = true;
}
});
}
else if(idx_up_diff[0] < 0)
{
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDimLow, 1>{}([&](auto ireverse) {
constexpr index_t i = nDimLow - 1 - ireverse;
if(borrow)
{
--idx_low_new(i);
}
borrow = false;
if(idx_low_new[i] < 0)
{
idx_low_new(i) += LowerLengths::At(i);
borrow = true;
}
});
}
return idx_low_new - idx_low_old;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

@@ -198,12 +251,12 @@ struct Merge
    }
};

- // UpLengths: Sequence<...>
- template <typename UpLengths>
+ // UpperLengths: Sequence<...>
+ template <typename UpperLengths>
struct Unmerge
{
    static constexpr index_t nDimLow = 1;
-   static constexpr index_t nDimUp = UpLengths::Size();
+   static constexpr index_t nDimUp = UpperLengths::Size();

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

@@ -212,23 +265,16 @@ struct Unmerge
    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

-   __host__ __device__ static constexpr auto GetLowerLengths()
-   {
-       constexpr index_t low_length =
-           accumulate_on_sequence(UpLengths{}, math::multiplies<index_t>{}, Number<1>{});
-
-       return Sequence<low_length>{};
-   }
-
-   __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
+   __host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low{0};

        constexpr auto pseudo_up_strides =
-           typename sequence_reverse_inclusive_scan<UpLengths, math::multiplies<index_t>, 1>::
-               type{};
+           reverse_inclusive_scan_sequence(
+               UpperLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
+               .PushBack(Number<1>{});

        static_for<0, nDimUp, 1>{}(
            [&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; });

@@ -245,47 +291,45 @@ struct Unmerge
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

+   // TODO: should this function be here? should it be specific for padding check?
+   __host__ __device__ static constexpr bool
+   IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+   {
+       return false;
+   }
};
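The carry/borrow passes added to Merge::CalculateLowerIndexDiff keep the lower multi-index in range after a run-time increment; a standalone host-side sketch of the carry direction (borrow is symmetric):

#include <array>
#include <cstddef>
#include <cstdint>

// After adding the lower-index diff, a component can overflow its length; this pass
// restores a valid multi-index, mirroring the carry loop in CalculateLowerIndexDiff.
template <std::size_t N>
std::array<int64_t, N> normalize_with_carry(std::array<int64_t, N> idx,
                                            const std::array<int64_t, N>& lengths)
{
    bool carry = false;
    for(std::size_t r = 0; r < N; ++r) // reversed order: lowest dimension first
    {
        const std::size_t i = N - 1 - r;
        if(carry)
            ++idx[i];
        carry = false;
        if(idx[i] >= lengths[i])
        {
            idx[i] -= lengths[i];
            carry = true;
        }
    }
    return idx;
}

// e.g. with lengths {4, 3, 5}: {1, 2, 5} (one past the end of the last dimension)
// normalizes to {2, 0, 0}.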
- // UpLengths: Sequence<...>
+ // UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
- template <index_t LowLength, typename UpLengths, typename Coefficients>
+ template <typename UpperLengths, typename Coefficients>
struct Embed
{
    static constexpr index_t nDimLow = 1;
-   static constexpr index_t nDimUp = UpLengths::Size();
+   static constexpr index_t nDimUp = UpperLengths::Size();

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ explicit constexpr Embed()
    {
-       static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
+       static_assert(UpperLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
                      "wrong! # of dimensions not consistent");
-
-       constexpr index_t low_id_max =
-           Coefficients::Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
-                                                         math::plus<index_t>{},
-                                                         Number<0>{});
-
-       static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
    }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

-   __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
-
-   __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
+   __host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low(Coefficients{}[nDimUp]);

        static_for<0, nDimUp, 1>{}(
-           [&](auto idim) { idx_low[0] += idx_up[idim] * Coefficients{}[idim]; });
+           [&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });

        return idx_low;
    }

@@ -298,12 +342,18 @@ struct Embed
        LowerIndex idx_low_diff{0};

        static_for<0, nDimUp, 1>{}(
-           [&](auto idim) { idx_low_diff[0] += idx_up_diff[idim] * Coefficients{}[idim]; });
+           [&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });

        return idx_low_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

+   __host__ __device__ static constexpr bool
+   IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+   {
+       return false;
+   }
};

} // namespace ck
...
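Embed is the affine map idx_low = sum_i coefficients[i] * idx_up[i] + coefficients[nDimUp]; a small standalone sketch, with the convolution reading (dilation, stride, 0) noted in comments:

#include <array>
#include <cstddef>
#include <cstdint>

// idx_low = sum_i coefficients[i] * idx_up[i] + coefficients[N]
// In the padded kernel, Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>
// turns a (filter row, output row) pair into a padded input row this way.
template <std::size_t N>
int64_t embed_lower_index(const std::array<int64_t, N>& idx_up,
                          const std::array<int64_t, N + 1>& coefficients)
{
    int64_t idx_low = coefficients[N];
    for(std::size_t i = 0; i < N; ++i)
        idx_low += coefficients[i] * idx_up[i];
    return idx_low;
}

// e.g. dilation 2, stride 1: embed_lower_index<2>({1, 7}, {2, 1, 0}) == 9,
// i.e. filter row 1 applied at output row 7 reads padded input row 9.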
@@ -207,10 +207,12 @@ struct TransformedTensorDescriptor
        return LowTensorDescriptor{};
    }

+ #if 0
    __host__ __device__ static constexpr auto GetLowerLengths()
    {
        return GetLowerTensorDescriptor().GetLengths();
    }
+ #endif

    struct lambda_GetUpperLengths
    {

@@ -383,35 +385,5 @@ }
    }
};
template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
Sequence<Strides...>)
{
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
}
template <typename Lengths>
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
{
constexpr auto strides = reverse_inclusive_scan_sequence(
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
return make_native_tensor_descriptor(Lengths{}, strides);
}
template <typename LowTensorDescriptor,
typename Transforms,
typename LowDimensionIds,
typename UpDimensionIds>
__host__ __device__ constexpr auto
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
{
return TransformedTensorDescriptor<LowTensorDescriptor,
Transforms,
LowDimensionIds,
UpDimensionIds>{};
}
} // namespace ck

#endif
@@ -6,6 +6,96 @@
namespace ck {
template <typename Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{
return reverse_inclusive_scan_sequence(
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
}
template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
{
constexpr index_t L_back_align =
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
return calculate_tensor_strides_packed(
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}
template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
Sequence<Strides...>)
{
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
}
template <typename Lengths>
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
{
constexpr auto strides = calculate_tensor_strides_packed(Lengths{});
return make_native_tensor_descriptor(Lengths{}, strides);
}
template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto make_native_tensor_descriptor_aligned(Lengths, Number<Align>)
{
constexpr auto strides = calculate_tensor_strides_aligned(Lengths{}, Number<Align>{});
return make_native_tensor_descriptor(Lengths{}, strides);
}
template <typename LowTensorDescriptor,
typename Transforms,
typename LowDimensionIds,
typename UpDimensionIds>
__host__ __device__ constexpr auto
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
{
return TransformedTensorDescriptor<LowTensorDescriptor,
Transforms,
LowDimensionIds,
UpDimensionIds>{};
}
template <typename LowerTensorDescriptor,
index_t... LowerLengths,
index_t... LowerDimensionIds,
index_t... UpperDimensionIds>
__host__ __device__ constexpr auto reorder_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>)
{
return TransformedTensorDescriptor<LowerTensorDescriptor,
Tuple<PassThrough<LowerLengths>...>,
Tuple<Sequence<LowerDimensionIds>...>,
Tuple<Sequence<UpperDimensionIds>...>>{};
}
template <typename LowerTensorDescriptor, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(LowerTensorDescriptor, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
return reorder_tensor_descriptor_impl(
LowerTensorDescriptor{},
LowerTensorDescriptor::GetLengths(),
typename arithmetic_sequence_gen<0, LowerTensorDescriptor::GetNumOfDimension(), 1>::type{},
MapLower2Upper{});
}
template <typename LowerTensorDescriptor, typename MapUpper2Lower>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_upper2lower(LowerTensorDescriptor, MapUpper2Lower)
{
return reorder_tensor_descriptor_given_lower2upper(
LowerTensorDescriptor{}, typename sequence_map_inverse<MapUpper2Lower>::type{});
}
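reorder_tensor_descriptor_given_upper2lower works by inverting the permutation and reusing the lower2upper path; a standalone sketch of that inversion (the concrete sequence below is the one used for the output thread descriptor earlier in this commit, read as an upper2lower map):

#include <array>
#include <cstddef>

// Given upper2lower[u] = l, the inverse satisfies lower2upper[l] = u, which is what
// sequence_map_inverse produces at compile time.
template <std::size_t N>
std::array<std::size_t, N> invert_permutation(const std::array<std::size_t, N>& upper2lower)
{
    std::array<std::size_t, N> lower2upper{};
    for(std::size_t u = 0; u < N; ++u)
        lower2upper[upper2lower[u]] = u;
    return lower2upper;
}

// e.g. the kernel's Sequence<4, 3, 7, 0, 1, 2, 5, 6>, read as an upper2lower map,
// inverts to {3, 4, 5, 1, 0, 6, 7, 2}.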
template <typename... NativeDimensions>
__host__ __device__ void
print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
...
@@ -951,10 +951,10 @@ struct ThreadwiseGenericTensorSliceCopy_v3r1
// The dimension access order should be the same on src and dst.
// It is designed for cases, where one of src and dst is register, and
// the other is device memory or LDS
- template <class SrcDesc,
-           class DstDesc,
-           class SliceLengths,
-           class DimAccessOrder,
+ template <typename SrcDesc,
+           typename DstDesc,
+           typename SliceLengths,
+           typename DimAccessOrder,
          index_t VectorAccessDim,
          index_t SrcDataPerAccess,
          index_t DstDataPerAccess>
...
@@ -91,8 +91,8 @@ int main(int argc, char* argv[])
    // 3x3, 34x34
    constexpr index_t N = 64;
    constexpr index_t C = 256;
-   constexpr index_t HI = 34;
-   constexpr index_t WI = 34;
+   constexpr index_t HI = 32;
+   constexpr index_t WI = 32;
    constexpr index_t K = 128;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

@@ -100,8 +100,8 @@ int main(int argc, char* argv[])
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

-   using LeftPads = Sequence<0, 0>;
-   using RightPads = Sequence<0, 0>;
+   using LeftPads = Sequence<1, 1>;
+   using RightPads = Sequence<1, 1>;
#elif 0
    // 1x1 filter, 8x8 image
    // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
...