Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
f00c1381
Commit
f00c1381
authored
Sep 20, 2019
by
Chao Liu
Browse files
adding logic to judge linear dimension
parent
bf7e7d62
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
5 deletions
+75
-5
composable_kernel/include/tensor_description/tensor_descriptor.hpp
...e_kernel/include/tensor_description/tensor_descriptor.hpp
+44
-4
composable_kernel/include/utility/sequence.hpp
composable_kernel/include/utility/sequence.hpp
+31
-1
No files found.
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
f00c1381
...
...
@@ -256,7 +256,7 @@ struct TransformedTensorDescriptor
constexpr
auto
sorted2unsorted_map
=
typename
sort_up_dimension_ids
::
sorted2unsorted_map
{};
constexpr
auto
sorted_up_lengths
=
pick_sequence_elements
(
mingled_up_lengths
,
sorted2unsorted_map
);
pick_sequence_elements
_by_ids
(
mingled_up_lengths
,
sorted2unsorted_map
);
return
sorted_up_lengths
;
}
...
...
@@ -347,20 +347,60 @@ struct TransformedTensorDescriptor
}
#if 0
struct lambda_sequence_logic_or
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
// TODO: should use math::logic_or<bool>, after Sequence can take bool
return typename sequence_reduce<math::logic_or<index_t>, Seqs...>::type{};
}
};
struct lambda_1
{
template <typename Transform>
__host__ __device__ constexpr auto operator()(const Transform& tran) const
{
return tran.GetUpperLengths();
}
};
template <index_t IDim>
__host__ __device__ static constexpr bool GetMaskOfLinearDimensions()
{
// create tuple of linear dimension masks, for all transformations
constexpr auto tuple_of_linear_dimension_mask =
transform_tuple(lambda_1, Transforms{});
// reduce tuple of masks into one mask
constexpr auto linear_dimension_mask =
unpack(lambda_sequence_logic_or{}, tuple_of_linear_dimension_mask);
return linear_dimension_mask;
}
template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{
// not implemented
return GetMaskOfLinearDimensions().At(Number<IDim>{});
}
__host__ __device__ static constexpr auto GetLinearDimensions()
{
// not implemented
constexpr auto linear_dimension_mask = GetMaskOfLienarDimensions();
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask);
}
__host__ __device__ static constexpr auto GetNonLinearDimensions()
{
// not implemented
constexpr auto nonlinear_dimension_mask =
GetMaskOfLienarDimensions().Transform(math::logic_not<index_t>{});
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
}
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
...
...
composable_kernel/include/utility/sequence.hpp
View file @
f00c1381
...
...
@@ -311,6 +311,28 @@ struct sequence_reverse<Sequence<I0, I1>>
using
type
=
Sequence
<
I1
,
I0
>
;
};
#if 0
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
using type = typename sequence_reduce<Reduce,
Seq,
typename sequence_reduce<Reduce, Seqs...>::type>::type;
};
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Reduce{}(Xs, Ys)...>;
};
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
using type = Seq;
};
#endif
template
<
typename
Values
,
typename
Ids
,
typename
Compare
>
struct
sequence_sort_impl
{
...
...
@@ -728,11 +750,19 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
}
template
<
typename
Seq
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
pick_sequence_elements
(
Seq
,
Sequence
<
Is
...
>
)
__host__
__device__
constexpr
auto
pick_sequence_elements
_by_ids
(
Seq
,
Sequence
<
Is
...
>
/* ids */
)
{
return
Sequence
<
Seq
::
At
(
Number
<
Is
>
{})...
>
{};
}
#if 0
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
// not implemented
}
#endif
template
<
typename
Seq
,
typename
Reduce
>
struct
lambda_accumulate_on_sequence
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment