Commit 625838de authored by Chao Liu
Browse files

added tuple

parent 12da8154
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_batched_gemm.hpp" #include "blockwise_batched_gemm.hpp"
...@@ -45,6 +47,7 @@ template <index_t GridSize, ...@@ -45,6 +47,7 @@ template <index_t GridSize,
index_t OutThreadCopyDataPerAccess_N> index_t OutThreadCopyDataPerAccess_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
{ {
#if 0
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const Float* const __restrict__ p_out_global) const
...@@ -478,6 +481,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded ...@@ -478,6 +481,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
#endif #endif
}); });
} }
#else
// Kernel entry point (debug/scaffolding version, selected by the enclosing #else).
// The real convolution implementation is compiled out; this body only smoke-tests
// the new ck::Tuple container and the native tensor descriptor from device code.
// p_in_global / p_wei_global / p_out_global: device pointers to input, weight and
// output tensors; none of them are actually dereferenced in this debug body.
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
#if 0
// std::tuple variant — disabled (std::tuple is not usable in device code on this toolchain).
constexpr auto tmp = std::tuple<bool>{};
constexpr auto flag = std::get<0>(tmp);
#else
// Heterogeneous ck::Tuple holding a bool, an empty-state Sequence type and an index_t.
constexpr auto a = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
// Print from exactly one thread of one block to avoid interleaved output.
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("adsas %d\n", a.At(Number<0>{}));
print_Sequence("seq", a.At(Number<1>{}));
// NOTE(review): "%lu" assumes index_t is unsigned long here, but the same
// element is printed with "%d" below — confirm index_t's width and sign.
printf("adsas %lu\n", a.At(Number<2>{}));
}
// Mutable Tuple: verifies that At() returns a writable reference on a non-const object.
auto b = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
b.At(Number<0>{}) = false;
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("adsas %d\n", b.At(Number<0>{}));
print_Sequence("seq", b.At(Number<1>{}));
printf("adsas %lu\n", b.At(Number<2>{}));
}
// Rvalue Tuple: verifies At() works on a temporary constructed in the call expression.
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("adsas %d\n",
Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<0>{}));
print_Sequence(
"seq", Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<1>{}));
printf("adsas %d\n",
Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<2>{}));
}
#endif
#if 0
// Disabled follow-up experiment: build a NativeTensorDescriptor from the
// compile-time lengths/strides of InGlobalDesc and print it from thread 0.
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// create a native tensor descriptor
constexpr auto in_n_c_h_w_global_desc =
make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_tensor_descriptor("in_n_c_h_w_global_desc", in_n_c_h_w_global_desc);
}
// transform the tensor descriptor once
//
// calculate the offset of some entry
#endif
}
#endif
}; };
} // namespace ck } // namespace ck
......
...@@ -12,15 +12,17 @@ struct Dimension ...@@ -12,15 +12,17 @@ struct Dimension
}; };
template <index_t Length, index_t Stride> template <index_t Length, index_t Stride>
struct NativeDimension : Dimension<Length> struct NativeDimension
{ {
__host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
__host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; } __host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
__host__ __device__ static constexpr index_t GetOffset(index_t id) { return id * Stride; } __host__ __device__ static constexpr index_t GetOffset(index_t i) { return i * Stride; }
__host__ __device__ static constexpr index_t GetOffsetDiff(index_t id_diff) __host__ __device__ static constexpr index_t GetOffsetDiff(index_t i_diff)
{ {
return id_diff * Stride; return i_diff * Stride;
} }
}; };
......
...@@ -8,25 +8,19 @@ namespace ck { ...@@ -8,25 +8,19 @@ namespace ck {
template <index_t N> template <index_t N>
using MultiIndex = Array<index_t, N>; using MultiIndex = Array<index_t, N>;
// LowLengths: Sequence<...> template <index_t Length>
template <class LowLengths>
struct PassThrough struct PassThrough
{ {
static constexpr index_t nDim = LowLengths::GetSize(); using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
using LowerIndex = MultiIndex<nDim>;
using UpperIndex = LowerIndex;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; } __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }
{
return GetNumOfLowerDimension();
}
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; } __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<Length>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); } __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; } __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; }
...@@ -35,7 +29,7 @@ struct PassThrough ...@@ -35,7 +29,7 @@ struct PassThrough
return idx_up_diff; return idx_up_diff;
} }
__host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
}; };
// LowLengths: Sequence<...> // LowLengths: Sequence<...>
...@@ -45,25 +39,22 @@ struct Pad ...@@ -45,25 +39,22 @@ struct Pad
static constexpr index_t nDim = LowLengths::GetSize(); static constexpr index_t nDim = LowLengths::GetSize();
using LowerIndex = MultiIndex<nDim>; using LowerIndex = MultiIndex<nDim>;
using UpperIndex = LowerIndex; using UpperIndex = MultiIndex<nDim>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; } __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
{
return GetNumOfLowerDimension();
}
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; } __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
__host__ __device__ static constexpr auto GetUpperLengths() __host__ __device__ static constexpr auto GetUpperLengths()
{ {
return GetLowerLengths() + LeftPads + RightPads; return GetLowerLengths() + LeftPads{} + RightPads{};
} }
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
{ {
return idx_up - LeftPads; return idx_up - LeftPads{};
} }
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
...@@ -71,9 +62,10 @@ struct Pad ...@@ -71,9 +62,10 @@ struct Pad
return idx_up_diff; return idx_up_diff;
} }
__host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
}; };
#if 0
// LowLengths: Sequence<...> // LowLengths: Sequence<...>
template <class LowLengths> template <class LowLengths>
struct Merge struct Merge
...@@ -116,8 +108,9 @@ struct Merge ...@@ -116,8 +108,9 @@ struct Merge
return idx_low_diff; return idx_low_diff;
} }
__host__ __device__ static constexpr bool IsIndexTransformLinear() { return false; } __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
}; };
#endif
// UpLengths: Sequence<...> // UpLengths: Sequence<...>
template <index_t LowLength, class UpLengths> template <index_t LowLength, class UpLengths>
...@@ -126,6 +119,9 @@ struct Unmerge ...@@ -126,6 +119,9 @@ struct Unmerge
static constexpr index_t nDimLow = 1; static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpLengths::GetSize(); static constexpr index_t nDimUp = UpLengths::GetSize();
using UpperIndex = MultiIndex<nDimUp>;
using LowerIndex = MultiIndex<nDimLow>;
__host__ __device__ constexpr Unmerge() __host__ __device__ constexpr Unmerge()
{ {
static_assert(LowLength == accumulate_on_sequence( static_assert(LowLength == accumulate_on_sequence(
...@@ -133,7 +129,7 @@ struct Unmerge ...@@ -133,7 +129,7 @@ struct Unmerge
"wrong! UpLengths need to be "); "wrong! UpLengths need to be ");
} }
__host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number<nDimUp>{}}; __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; } __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
...@@ -149,7 +145,7 @@ struct Unmerge ...@@ -149,7 +145,7 @@ struct Unmerge
LowerIndex idx_low{0}; LowerIndex idx_low{0};
static_for<0, nDim, 1>{}([&](auto idim) { idx_low[0] += idx_up[idim] * scans[idim]; }); static_for<0, nDimUp, 1>{}([&](auto idim) { idx_low(0) += idx_up[idim] * scans[idim]; });
return idx_low; return idx_low;
} }
...@@ -159,7 +155,7 @@ struct Unmerge ...@@ -159,7 +155,7 @@ struct Unmerge
return GetLowerIndex(idx_up_diff); return GetLowerIndex(idx_up_diff);
} }
__host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
}; };
// UpLengths: Sequence<...> // UpLengths: Sequence<...>
...@@ -171,7 +167,8 @@ struct Embed ...@@ -171,7 +167,8 @@ struct Embed
static constexpr index_t nDimLow = 1; static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpLengths::GetSize(); static constexpr index_t nDimUp = UpLengths::GetSize();
static constexpr auto mCoefficients = Coefficients{}; using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ constexpr Embed() __host__ __device__ constexpr Embed()
{ {
...@@ -179,14 +176,14 @@ struct Embed ...@@ -179,14 +176,14 @@ struct Embed
"wrong! # of dimensions not consistent"); "wrong! # of dimensions not consistent");
constexpr index_t low_id_max = constexpr index_t low_id_max =
Coefficents.Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(), Coefficients::Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
math::plus<index_t>{}, math::plus<index_t>{},
Number<0>{}); Number<0>{});
static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range"); static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
} }
__host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number<nDimUp>{}}; __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; } __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
...@@ -196,10 +193,10 @@ struct Embed ...@@ -196,10 +193,10 @@ struct Embed
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
{ {
LowerIndex idx_low{mCoefficients[nDimUp]}; LowerIndex idx_low(Coefficients{}[nDimUp]);
static_for<0, nDimUp, 1>{}( static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low[0] += idx_up[idim] * mCoefficients[idim]; }); [&](auto idim) { idx_low[0] += idx_up[idim] * Coefficients{}[idim]; });
return idx_low; return idx_low;
} }
...@@ -209,12 +206,12 @@ struct Embed ...@@ -209,12 +206,12 @@ struct Embed
LowerIndex idx_low_diff{0}; LowerIndex idx_low_diff{0};
static_for<0, nDimUp, 1>{}( static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low_diff[0] += idx_up_diff[idim] * mCoefficients[idim]; }); [&](auto idim) { idx_low_diff[0] += idx_up_diff[idim] * Coefficients{}[idim]; });
return idx_low_diff; return idx_low_diff;
} }
__host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
}; };
} // namespace ck } // namespace ck
......
...@@ -11,21 +11,39 @@ template <class... NativeDimensions> ...@@ -11,21 +11,39 @@ template <class... NativeDimensions>
struct NativeTensorDescriptor struct NativeTensorDescriptor
{ {
using type = NativeTensorDescriptor; using type = NativeTensorDescriptor;
static constexpr auto mDimensions = Tuple<NativeDimensions...>; static constexpr auto mDimensions = Tuple<NativeDimensions...>{};
static constexpr index_t nDim = mDimensions::GetSize(); static constexpr index_t nDim = mDimensions.GetSize();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; } __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
struct lambda_GetLength
{
template <class IDim>
__host__ __device__ constexpr auto operator()(IDim) const
{
return GetLength(IDim{});
}
};
__host__ __device__ static constexpr auto GetLengths() __host__ __device__ static constexpr auto GetLengths()
{ {
// not implemented return typename sequence_gen<nDim, lambda_GetLength>::type{};
} }
struct lambda_GetStride
{
template <class IDim>
__host__ __device__ constexpr auto operator()(IDim) const
{
return GetStride(IDim{});
}
};
__host__ __device__ static constexpr auto GetStrides() __host__ __device__ static constexpr auto GetStrides()
{ {
// not implemented return typename sequence_gen<nDim, lambda_GetStride>::type{};
} }
template <index_t IDim> template <index_t IDim>
...@@ -59,20 +77,26 @@ struct NativeTensorDescriptor ...@@ -59,20 +77,26 @@ struct NativeTensorDescriptor
return offset_diff; return offset_diff;
} }
__host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear(); template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{ {
// TODO: re-implement "Sequence", so that it can take other data-type (including bool) as return true;
// element
return uniform_sequence_gen<nDim, 1>{};
} }
__host__ __device__ static constexpr auto GetIndependentDimensionGroups() __host__ __device__ static constexpr auto GetLinearDimensions()
{ {
// not implemented, should return Tuple<Sequence<0>, Sequence<1>, ...> return typename arithmetic_sequence_gen<0, nDim, 1>::type{};
return xxx; }
__host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; }
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
{
return Tuple<>{};
} }
}; };
#if 0
// LowerTensorDescriptor // LowerTensorDescriptor
// Transforms: std::tuple<DimensionTransforms...> // Transforms: std::tuple<DimensionTransforms...>
// LowerDimensionIds: std::tuple<Sequence<...>> // LowerDimensionIds: std::tuple<Sequence<...>>
...@@ -213,16 +237,45 @@ struct TransformedTensorDescriptor ...@@ -213,16 +237,45 @@ struct TransformedTensorDescriptor
return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up)); return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
} }
__host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear(); template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>);
{
// not implemented
}
__host__ __device__ static constexpr auto GetLinearDimensions()
{
// not implemented
}
__host__ __device__ static constexpr auto GetNonLinearDimensions()
{ {
// not implemented // not implemented
} }
__host__ __device__ static constexpr auto GetIndependentDimensionGroups() __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
{ {
// not implemented // not implemented
} }
}; };
#endif
// Factory: build a NativeTensorDescriptor by zipping two equal-length Sequences
// of per-dimension lengths and strides into one NativeDimension per dimension.
// Both arguments carry their data purely in the type; the values are unused.
template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>,
                                                               Sequence<Strides...>)
{
    // Pairwise pack expansion: dimension i gets <Lengths[i], Strides[i]>.
    using descriptor_type = NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>;
    return descriptor_type{};
}
// Factory: build a packed (densely strided, row-major) NativeTensorDescriptor
// from a Sequence of lengths. The stride of each dimension is the product of
// all faster-varying (trailing) lengths, with the innermost stride fixed at 1.
template <class Lengths>
__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths)
{
    // BUG FIX: the scan yields a Sequence<...> type, not an index_t scalar —
    // declaring the result as "constexpr index_t" cannot compile. Deduce it.
    constexpr auto strides = reverse_inclusive_scan_sequence(
                                 Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
                                 .PushBack(Number<1>{});

    return make_NativeTensorDescriptor(Lengths{}, strides);
}
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
namespace ck {
// Debug helper: print the name, rank, lengths and strides of a
// NativeTensorDescriptor. Forwards the descriptor's compile-time
// lengths/strides Sequences to print_tensor_descriptor_impl, which
// does the rank-dispatched printf. Callable from host and device;
// device-side printf output is for debugging only.
template <class... NativeDimensions>
__host__ __device__ void print_tensor_descriptor(const char* s,
                                                 NativeTensorDescriptor<NativeDimensions...> desc)
{
    print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
}
// Implementation detail of print_tensor_descriptor: expands the compile-time
// Lengths/Strides packs into printf arguments. Because a printf format string
// must be a literal with a fixed argument count, the rank is dispatched with
// static_if over every supported dimensionality (1..12); exactly one branch
// is instantiated per call site.
// NOTE(review): "%u" is used for index_t values — confirm index_t is unsigned
// int on all supported toolchains, otherwise the specifier is mismatched.
template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
{
    constexpr index_t nDim = sizeof...(Lengths);

    // Only ranks 1 through 12 have a matching printf branch below.
    static_assert(nDim > 0 && nDim <= 12, "wrong!");

    // One branch per rank; Lengths.../Strides... expand to nDim arguments each.
    static_if<nDim == 1>{}([&](auto) {
        printf("%s dim %u, lengths {%u}, strides {%u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 2>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 3>{}([&](auto) {
        printf(
            "%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 4>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 5>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 6>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 7>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 8>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 9>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
               "%u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 10>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
               "%u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 11>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
               "%u %u "
               "%u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 12>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
               "%u %u %u %u "
               "%u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });
}
} // namespace ck
#endif
...@@ -85,6 +85,7 @@ struct TensorVisit ...@@ -85,6 +85,7 @@ struct TensorVisit
{ {
constexpr auto nonlinear_independent_dimensions_igroup = constexpr auto nonlinear_independent_dimensions_igroup =
nonlinear_independent_dimension_groups.Get(igroup); nonlinear_independent_dimension_groups.Get(igroup);
constexpr auto nonlinear_independent_lengths_igroup = constexpr auto nonlinear_independent_lengths_igroup =
lambda_HackLengths{}(lengths, nonlinear_independent_dimensions_igroup); lambda_HackLengths{}(lengths, nonlinear_independent_dimensions_igroup);
......
...@@ -82,9 +82,11 @@ struct Array ...@@ -82,9 +82,11 @@ struct Array
// A: Array // A: Array
// Picks: Sequence<...> // Picks: Sequence<...>
template <class Arr, class Picks> template <class Arr, class Picks>
ArrayElementPicker struct ArrayElementPicker
{ {
__host__ __device__ constexpr ArrayElementPicker(Arr & array) : mData{array} using data_type = typename Arr::data_type;
__host__ __device__ constexpr ArrayElementPicker(Arr& array) : mData{array}
{ {
constexpr index_t imax = constexpr index_t imax =
accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{}); accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
...@@ -95,26 +97,26 @@ ArrayElementPicker ...@@ -95,26 +97,26 @@ ArrayElementPicker
__host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); } __host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); }
template <index_t I> template <index_t I>
__host__ __device__ constexpr TData operator[](Number<I>) const __host__ __device__ constexpr data_type operator[](Number<I>) const
{ {
constexpr auto IP = Picks::Get(Number<I>{}); constexpr auto IP = Picks::Get(Number<I>{});
return mData[IP]; return mData[IP];
} }
__host__ __device__ constexpr TData operator[](index_t i) const __host__ __device__ constexpr data_type operator[](index_t i) const
{ {
constexpr index_t ip = Picks{}[i]; constexpr index_t ip = Picks{}[i];
return mData[ip]; return mData[ip];
} }
template <index_t I> template <index_t I>
__host__ __device__ TData& operator()(Number<I>) __host__ __device__ data_type& operator()(Number<I>)
{ {
constexpr auto IP = Picks::Get(Number<I>{}); constexpr auto IP = Picks::Get(Number<I>{});
return mData[IP]; return mData[IP];
} }
__host__ __device__ TData& operator()(index_t i) __host__ __device__ data_type& operator()(index_t i)
{ {
constexpr index_t ip = Picks{}[i]; constexpr index_t ip = Picks{}[i];
return mData[ip]; return mData[ip];
......
...@@ -2,66 +2,99 @@ ...@@ -2,66 +2,99 @@
#define CK_TUPLE_HPP #define CK_TUPLE_HPP
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "Sequence.hpp"
namespace ck { namespace ck {
template <class... Ts> namespace detail {
struct tuple : public std::tuple<Ts...>
{
using type = tuple;
__host__ __device__ static constexpr index_t GetSize() { return std::tuple_size(tuple{}); } template <index_t>
struct TupleElementKey
{
};
template <index_t I> template <typename Key, typename Data>
__host__ __device__ constexpr auto Get(Number<I>) const struct TupleElement
{
template <typename T>
__host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
{ {
return std::get<I>(*this);
} }
template <index_t I> Data mData;
__host__ __device__ constexpr auto operator[](Number<I>) const
{
return Get(Number<I>{}) :
}
}; };
// merge tuple template <typename Key, typename Data>
template <class... Tuples> __host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
__host__ __device__ constexpr auto merge_tuple(Tuples&&... xs)
{ {
return std::tuple_cat(xs...); return x.mData;
}; }
// generate sequence template <typename Key, typename Data>
template <index_t IBegin, index_t NRemain, class F> __host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
struct tuple_gen_impl
{ {
static constexpr index_t NRemainLeft = NRemain / 2; return x.mData;
static constexpr index_t NRemainRight = NRemain - NRemainLeft; }
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type = template <typename Key, typename Data>
typename tuple_merge<typename tuple_gen_impl<IBegin, NRemainLeft, F>::type, __host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
typename tuple_gen_impl<IMiddle, NRemainRight, F>::type>::type;
};
template <index_t I, class F>
struct tuple_gen_impl<I, 1, F>
{ {
static constexpr auto x = F{}(Number<I>{}); return static_cast<Data&&>(x.mData);
using type = tuple<Is>; }
};
template <index_t I, class F> template <typename Indices, typename... Xs>
struct sequence_gen_impl<I, 0, F> struct TupleImpl;
template <index_t... Is, typename... Xs>
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
{ {
using type = Sequence<>; template <typename... Ys>
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
{
}
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I>
__host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
{
return get_tuple_element<TupleElementKey<I>>(*this);
}
template <index_t I>
__host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
{
return get_tuple_element<TupleElementKey<I>>(*this);
}
}; };
template <index_t NSize, class F> } // namespace detail
struct sequence_gen
template <typename... Xs>
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
{ {
using type = typename sequence_gen_impl<0, NSize, F>::type; using base =
detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;
template <typename... Ys>
__host__ __device__ explicit constexpr Tuple(Ys&&... ys) : base(static_cast<Ys&&>(ys)...)
{
}
template <index_t I>
__host__ __device__ constexpr const auto& At(Number<I>) const
{
static_assert(I < base::Size(), "wrong! out of range");
return GetElementByKey(detail::TupleElementKey<I>{});
}
template <index_t I>
__host__ __device__ constexpr auto& At(Number<I>)
{
static_assert(I < base::Size(), "wrong! out of range");
return GetElementByKey(detail::TupleElementKey<I>{});
}
}; };
} // namespace ck } // namespace ck
......
...@@ -65,9 +65,6 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw, ...@@ -65,9 +65,6 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
index_t h_pad_low = LowerPads{}.Get(Number<0>{}); index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{}); index_t w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
auto f = [&](auto n, auto k, auto ho, auto wo) { auto f = [&](auto n, auto k, auto ho, auto wo) {
double v = 0; double v = 0;
for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c) for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c)
...@@ -125,9 +122,6 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw, ...@@ -125,9 +122,6 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
index_t h_pad_low = LowerPads{}.Get(Number<0>{}); index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{}); index_t w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
std::size_t HiPerTile = HoPerTile + Y - 1; std::size_t HiPerTile = HoPerTile + Y - 1;
std::size_t WiPerTile = WoPerTile + X - 1; std::size_t WiPerTile = WoPerTile + X - 1;
......
...@@ -368,7 +368,7 @@ int main(int argc, char* argv[]) ...@@ -368,7 +368,7 @@ int main(int argc, char* argv[])
#if 0 #if 0
device_convolution_direct_v2_nchw_kcyx_nkhw device_convolution_direct_v2_nchw_kcyx_nkhw
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn( device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(
in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif 1 #elif 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment