Commit 8b306478 authored by Chao Liu's avatar Chao Liu
Browse files

Added TensorAdaptor class and used it to implement the cluster descriptor

parent 2178d1d8
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
// TODO remove dependency on deprecated tensor descriptor // TODO remove dependency on deprecated tensor descriptor
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_adaptor.hpp"
namespace ck { namespace ck {
...@@ -44,5 +45,27 @@ __host__ __device__ constexpr auto make_cluster_descriptor( ...@@ -44,5 +45,27 @@ __host__ __device__ constexpr auto make_cluster_descriptor(
return ClusterDescriptor<Lengths, decltype(order)>{}; return ClusterDescriptor<Lengths, decltype(order)>{};
} }
// Builds a cluster descriptor expressed as a tensor adaptor.
//
// The adaptor holds a single merge transform whose lower lengths are `Lengths`
// reordered by `ArrangeOrder` (new-to-old). The transform's lower dimensions are
// tied to the old top ids given by `ArrangeOrder`, and its single upper dimension
// becomes new top id 0.
//
// Only the types of the arguments are used: both `Lengths` and `ArrangeOrder` are
// default-constructed inside the function, so the value of `order` is ignored.
template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor_v2(
    Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
    // dimension ids handed to the adaptor: the lower side follows the arrange order,
    // the upper side is the single merged dimension
    constexpr auto merged_new_top_ids = Sequence<0>{};
    constexpr auto source_old_top_ids = ArrangeOrder{};

    // cluster lengths in arrange order
    constexpr auto lengths_in_order = Lengths::ReorderGivenNew2Old(ArrangeOrder{});
    constexpr index_t num_low_dims  = lengths_in_order.Size();

    // lift every length into a compile-time Number<> so the merge transform sees
    // static lengths
    constexpr auto merge_lengths = generate_tuple(
        [&](auto i) { return Number<lengths_in_order[i]>{}; }, Number<num_low_dims>{});

    constexpr auto merge = make_merge_transform(merge_lengths);

    return make_simple_tensor_adaptor(
        make_tuple(merge), make_tuple(source_old_top_ids), make_tuple(merged_new_top_ids));
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -376,13 +376,13 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, ...@@ -376,13 +376,13 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
// put everything together // put everything together
const auto all_transforms = container_cat(old_tensor_desc.GetTransforms(), new_transforms); const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms);
constexpr auto all_low_dim_hidden_idss = constexpr auto all_low_dim_hidden_idss =
container_cat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss); container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);
constexpr auto all_up_dim_hidden_idss = constexpr auto all_up_dim_hidden_idss =
container_cat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss); container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);
const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
......
...@@ -19,7 +19,30 @@ template <typename Transforms, ...@@ -19,7 +19,30 @@ template <typename Transforms,
typename TopDimensionHiddenIds> typename TopDimensionHiddenIds>
struct TensorAdaptor struct TensorAdaptor
{ {
private: __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
// Read-only access to the transforms stored in this adaptor.
__host__ __device__ constexpr const auto& GetTransforms() const
{
    const auto& stored_transforms = transforms_;
    return stored_transforms;
}
// Returns a default-constructed LowerDimensionHiddenIdss: the hidden ids of the
// lower dimensions, which callers index as [itran][idim_low].
__host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss()
{
    using LowerIdss = LowerDimensionHiddenIdss;
    return LowerIdss{};
}
// Returns a default-constructed UpperDimensionHiddenIdss: the hidden ids of the
// upper dimensions, which callers index as [itran][idim_up].
__host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss()
{
    using UpperIdss = UpperDimensionHiddenIdss;
    return UpperIdss{};
}
// Returns a default-constructed TopDimensionHiddenIds, i.e. the hidden ids of the
// adaptor's top dimensions.
__host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
{
    using TopIds = TopDimensionHiddenIds;
    return TopIds{};
}
// Returns a default-constructed BottomDimensionHiddenIds, i.e. the hidden ids of
// the adaptor's bottom dimensions.
__host__ __device__ static constexpr auto GetBottomDimensionHiddenIds()
{
    using BottomIds = BottomDimensionHiddenIds;
    return BottomIds{};
}
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{ {
const auto lengths = generate_tuple( const auto lengths = generate_tuple(
...@@ -71,6 +94,16 @@ struct TensorAdaptor ...@@ -71,6 +94,16 @@ struct TensorAdaptor
return make_tuple(itran_found, idim_up_found, found); return make_tuple(itran_found, idim_up_found, found);
} }
// Number of bottom dimensions, taken from the size of BottomDimensionHiddenIds.
__host__ __device__ static constexpr index_t GetNumOfBottomDimension()
{
    constexpr index_t num_bottom_dims = BottomDimensionHiddenIds::Size();
    return num_bottom_dims;
}
// Number of top dimensions, taken from the size of TopDimensionHiddenIds.
__host__ __device__ static constexpr index_t GetNumOfTopDimension()
{
    constexpr index_t num_top_dims = TopDimensionHiddenIds::Size();
    return num_top_dims;
}
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension() __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{ {
constexpr auto all_low_dim_ids = constexpr auto all_low_dim_ids =
...@@ -92,8 +125,8 @@ struct TensorAdaptor ...@@ -92,8 +125,8 @@ struct TensorAdaptor
constexpr static index_t ntransform_ = GetNumOfTransform(); constexpr static index_t ntransform_ = GetNumOfTransform();
constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
constexpr static index_t ndim_bottom_ = BottomDimensionHiddenIds::Size(); constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension();
constexpr static index_t ndim_top_ = TopDimensionHiddenIds::Size(); constexpr static index_t ndim_top_ = GetNumOfTopDimension();
using HiddenIndex = MultiIndex<ndim_hidden_>; using HiddenIndex = MultiIndex<ndim_hidden_>;
using BottomIndex = MultiIndex<ndim_bottom_>; using BottomIndex = MultiIndex<ndim_bottom_>;
...@@ -170,25 +203,25 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -170,25 +203,25 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// shift // shift
constexpr index_t adaptor0_max_hidden_id = [&]() { constexpr index_t adaptor0_max_hidden_id = [&]() {
index_t adaptor0_max_hidden_id = NumericalMinValue<index_t>::value; index_t adaptor0_max_hidden_id = NumericLimits<index_t>::Min();
static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) { static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low = constexpr index_t ndim_low =
TensorAdaptor0::GetTransforms()[itran].GetNumOfLowerDimension(); TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension();
static_for<0, ndim_low, 1>{}([&](auto idim_low) { static_for<0, ndim_low, 1>{}([&](auto idim_low) {
adaptor0_max_hidden_id = adaptor0_max_hidden_id =
math::max(adaptor0_max_hidden_id, math::max(adaptor0_max_hidden_id,
TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low]); TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
}); });
constexpr index_t ndim_up = constexpr index_t ndim_up =
TensorAdaptor0::GetTransforms()[itran].GetNumOfUpperDimension(); TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension();
static_for<0, ndim_up, 1>{}([&](auto idim_up) { static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor0_max_hidden_id = adaptor0_max_hidden_id =
math::max(adaptor0_max_hidden_id, math::max(adaptor0_max_hidden_id,
TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up]); TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
}); });
}); });
...@@ -196,25 +229,25 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -196,25 +229,25 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
}(); }();
constexpr index_t adaptor1_min_hidden_id = [&]() { constexpr index_t adaptor1_min_hidden_id = [&]() {
index_t adaptor1_min_hidden_id = NumericalMaxValue<index_t>::value; index_t adaptor1_min_hidden_id = NumericLimits<index_t>::Max();
static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) { static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low = constexpr index_t ndim_low =
TensorAdaptor0::GetTransforms()[itran].GetNumOfLowerDimension(); TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();
static_for<0, ndim_low, 1>{}([&](auto idim_low) { static_for<0, ndim_low, 1>{}([&](auto idim_low) {
adaptor1_min_hidden_id = adaptor1_min_hidden_id =
math::min(adaptor0_max_hidden_id, math::min(adaptor1_min_hidden_id,
TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low]); TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
}); });
constexpr index_t ndim_up = constexpr index_t ndim_up =
TensorAdaptor0::GetTransforms()[itran].GetNumOfUpperDimension(); TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();
static_for<0, ndim_up, 1>{}([&](auto idim_up) { static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor0_max_hidden_id = adaptor1_min_hidden_id =
math::min(adaptor0_max_hidden_id, math::min(adaptor1_min_hidden_id,
TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up]); TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
}); });
}); });
...@@ -224,25 +257,32 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -224,25 +257,32 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
constexpr index_t adaptor1_hidden_id_shift = constexpr index_t adaptor1_hidden_id_shift =
adaptor1_min_hidden_id - adaptor0_max_hidden_id + 1; adaptor1_min_hidden_id - adaptor0_max_hidden_id + 1;
constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();
// all_low_dim_hidden_idss = // all_low_dim_hidden_idss =
// low_dim_hidden_idss_0 + shift_hidden_id_for_1(match_hidden_id_for_1(low_dim_hiden_idss_1)) // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1))
constexpr auto low_dim_hidden_idss_1 = generate_tuple( constexpr auto low_dim_hidden_idss_1 = generate_tuple(
// generate sequence of ids for a transform // generate sequence of ids for a transform
[&](auto itran) { [&](auto itran) {
constexpr auto ndim_low_1 = constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size();
TensorAdpator1::GetLowerDimensionsHiddenIdss()[itran].Size();
constexpr auto low_dim_hidden_ids_1 = constexpr auto low_dim_hidden_ids_1 =
TensorAdpator1::GetLowerDimensionsHiddenIdss()[itran]; TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];
// sequence in, sequence out // sequence in, sequence out
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
{ {
constexpr auto low_dim_hidden_ids_1_mod = to_multi_index(low_dim_hidden_ids_1); auto low_dim_hidden_ids_1_mod = to_multi_index(low_dim_hidden_ids_1);
// shift hidden id so every dim id is unique
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
low_dim_hidden_ids_1_mod(idim_low_1) -= adaptor1_hidden_id_shift;
});
// match hidden id // match hidden id
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) { static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
// if this low dim is bottom dim, then do id matching
if constexpr(low_dim_hidden_ids_1[idim_low_1] == if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1])
{ {
...@@ -252,18 +292,13 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -252,18 +292,13 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
}); });
}); });
// shift hidden id return low_dim_hidden_ids_1_mod;
static_for<0, ndim_low_1, 1>{}[&](auto idim_low_1)
{
low_dim_hidden_ids_1_mod(idim_low_1) -= adaptor1_hidden_id_shift;
}
return generate_sequence([&](auto i) constexpr { return low_dim_hidden_ids_1[i]; },
Number<ndim_low_1>{});
} }
(); ();
return low_dim_hidden_ids_1_mod; return generate_sequence_v2(
[&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
Number<ndim_low_1>{});
}, },
Number<TensorAdaptor1::GetNumOfTransform()>{}); Number<TensorAdaptor1::GetNumOfTransform()>{});
...@@ -275,29 +310,29 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -275,29 +310,29 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
constexpr auto up_dim_hidden_idss_1 = generate_tuple( constexpr auto up_dim_hidden_idss_1 = generate_tuple(
// generate sequence of ids for a transform // generate sequence of ids for a transform
[&](auto itran) { [&](auto itran) {
constexpr auto ndim_up_1 = TensorAdpator1::GetUpperDimensionsHiddenIdss()[itran].Size(); constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size();
constexpr auto up_dim_hidden_ids_1 = constexpr auto up_dim_hidden_ids_1 =
TensorAdpator1::GetUpperDimensionsHiddenIdss()[itran]; TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];
// sequence in, sequence out // sequence in, constexpr tuple out
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
{ {
constexpr auto up_dim_hidden_ids_1_mod = to_multi_index(up_dim_hidden_ids_1); auto up_dim_hidden_ids_1_mod = to_multi_index(up_dim_hidden_ids_1);
// shift hidden id // shift hidden id
static_for<0, ndim_up_1, 1>{}[&](auto idim_up_1) static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
{
up_dim_hidden_ids_1_mod(idim_up_1) -= adaptor1_hidden_id_shift; up_dim_hidden_ids_1_mod(idim_up_1) -= adaptor1_hidden_id_shift;
} });
return generate_sequence( return up_dim_hidden_ids_1_mod;
[&](auto i) constexpr { return up_dim_hidden_ids_1_mod[i]; },
Number<ndim_up_1>{});
} }
(); ();
return up_dim_hidden_ids_1_mod; // constexpr tuple to sequence
return generate_sequence_v2(
[&](auto i) constexpr { return Number<up_dim_hidden_ids_1_mod[i]>{}; },
Number<ndim_up_1>{});
}, },
Number<TensorAdaptor1::GetNumOfTransform()>{}); Number<TensorAdaptor1::GetNumOfTransform()>{});
...@@ -305,10 +340,10 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -305,10 +340,10 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1); container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1);
// bottom_dim_hidden_ids = bottom_dim_hidden_ids_0 // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
constexpr bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds(); constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds();
// top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1) // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
constexpr top_dim_hidden_ids = constexpr auto top_dim_hidden_ids =
TensorAdaptor1::GetTopDimensionHiddenIds() - Number<adaptor1_hidden_id_shift>{}; TensorAdaptor1::GetTopDimensionHiddenIds() - Number<adaptor1_hidden_id_shift>{};
// put everything together // put everything together
...@@ -327,12 +362,45 @@ __host__ __device__ constexpr auto make_simple_tensor_adaptor(const Transforms& ...@@ -327,12 +362,45 @@ __host__ __device__ constexpr auto make_simple_tensor_adaptor(const Transforms&
LowerDimensionOldTopIdss, LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss) UpperDimensionNewTopIdss)
{ {
constexpr index_t ntransform = Transforms::Size();
static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
UpperDimensionNewTopIdss::Size() == ntransform,
"wrong!");
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
constexpr auto all_low_dim_old_top_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
LowerDimensionOldTopIdss{});
constexpr auto all_up_dim_new_top_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
UpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
"wrong!");
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();
// low_dim_hidden_idss // low_dim_hidden_idss
// up_dim_hidden_idss constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
constexpr auto up_dim_hidden_idss = generate_tuple(
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
Number<ntransform>{});
// bottom_dim_hidden_ids // bottom_dim_hidden_ids
constexpr auto bottom_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
// top_dim_hidden_ids // top_dim_hidden_ids
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
return TensorAdaptor<Transform, return TensorAdaptor<Transforms,
decltype(low_dim_hidden_idss), decltype(low_dim_hidden_idss),
decltype(up_dim_hidden_idss), decltype(up_dim_hidden_idss),
decltype(bottom_dim_hidden_ids), decltype(bottom_dim_hidden_ids),
......
...@@ -67,8 +67,8 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 ...@@ -67,8 +67,8 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
const auto thread_cluster_id = const auto thread_cluster_id = thread_cluster_desc_.CalculateBottomIndex(
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id()); make_multi_index(get_thread_local_1d_id()));
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{}; const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
...@@ -142,7 +142,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 ...@@ -142,7 +142,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
} }
static constexpr auto thread_cluster_desc_ = static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths, ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths,
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "container_helper.hpp" #include "container_helper.hpp"
#include "statically_indexed_array.hpp" #include "statically_indexed_array.hpp"
#include "container_element_picker.hpp" #include "container_element_picker.hpp"
#include "data_type.hpp"
#include "float_type.hpp" #include "float_type.hpp"
#include "buffer.hpp" #include "buffer.hpp"
#include "functional.hpp" #include "functional.hpp"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#define CK_DEVICE_BACKEND_AMD 1 #define CK_DEVICE_BACKEND_AMD 1
// GPU ID // GPU ID
#if 0 #if 1
#define CK_AMD_GPU_GFX906 1 #define CK_AMD_GPU_GFX906 1
#elif 0 #elif 0
#define CK_AMD_GPU_GFX908 1 #define CK_AMD_GPU_GFX908 1
......
...@@ -20,7 +20,8 @@ struct ContainerElementPicker ...@@ -20,7 +20,8 @@ struct ContainerElementPicker
__host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array} __host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array}
{ {
constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{}); constexpr index_t imax =
reduce_on_sequence(Picks{}, math::maximize<index_t>{}, Number<0>{});
static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
} }
...@@ -85,7 +86,8 @@ struct ConstantContainerElementPicker ...@@ -85,7 +86,8 @@ struct ConstantContainerElementPicker
__host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array} __host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array}
{ {
constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{}); constexpr index_t imax =
reduce_on_sequence(Picks{}, math::maximize<index_t>{}, Number<0>{});
static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
} }
......
...@@ -158,6 +158,7 @@ __host__ __device__ constexpr auto container_reduce_impl( ...@@ -158,6 +158,7 @@ __host__ __device__ constexpr auto container_reduce_impl(
} }
// rocm-4.1 compiler would crash for recursive lambda // rocm-4.1 compiler would crash for recursive lambda
// container reduce with initial value
template <typename Container, template <typename Container,
typename Reduce, typename Reduce,
typename Init, typename Init,
...@@ -176,23 +177,6 @@ __host__ __device__ constexpr auto container_reduce(const Container& x, ...@@ -176,23 +177,6 @@ __host__ __device__ constexpr auto container_reduce(const Container& x,
return container_reduce_impl( return container_reduce_impl(
x, reduce, init, Number<IBegin>{}, Number<IEnd>{}, Number<IStep>{}); x, reduce, init, Number<IBegin>{}, Number<IEnd>{}, Number<IStep>{});
} }
template <typename Container,
typename Reduce,
index_t IBegin = 0,
index_t IEnd = Container::Size(),
index_t IStep = 1>
__host__ __device__ constexpr auto container_reduce(const Container& x,
Reduce reduce,
Number<IBegin> = Number<0>{},
Number<IEnd> = Number<Container::Size()>{},
Number<IStep> = Number<1>{})
{
static_assert(IEnd > IBegin && (IEnd - IBegin) % IStep == 0, "wrong!");
return container_reduce_impl(
x, reduce, x[Number<IBegin>{}] Number<IBegin + 1>{}, Number<IEnd>{}, Number<IStep>{});
}
#endif #endif
template <typename TData, index_t NSize, typename Reduce> template <typename TData, index_t NSize, typename Reduce>
......
#ifndef CK_DATA_TYPE_HPP
#define CK_DATA_TYPE_HPP

// This header names int32_t and std::numeric_limits directly, so it pulls in the
// standard headers that declare them instead of relying on whatever the including
// file happened to include first.
#include <cstdint>
#include <limits>

namespace ck {

// Compile-time numeric bounds callable from host and device code.
// The primary template is declared but not defined: asking for the limits of a
// type that has no specialization is a compile error.
template <typename T>
struct NumericLimits;

// Bounds of int32_t, forwarded from std::numeric_limits.
template <>
struct NumericLimits<int32_t>
{
    // Smallest representable int32_t value.
    __host__ __device__ static constexpr int32_t Min()
    {
        return std::numeric_limits<int32_t>::min();
    }

    // Largest representable int32_t value.
    __host__ __device__ static constexpr int32_t Max()
    {
        return std::numeric_limits<int32_t>::max();
    }
};

} // namespace ck
#endif
...@@ -64,7 +64,7 @@ int main(int argc, char* argv[]) ...@@ -64,7 +64,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
...@@ -630,7 +630,7 @@ int main(int argc, char* argv[]) ...@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{})); print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 0 #if 1
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using acc_data_t = float; using acc_data_t = float;
...@@ -724,7 +724,7 @@ int main(int argc, char* argv[]) ...@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
in_vector_size, in_vector_size,
acc_data_t, acc_data_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment