Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
0271338e
Commit
0271338e
authored
Aug 06, 2019
by
Chao Liu
Browse files
added ReorderGiveOld2New() in Sequence and ConstantTensorDescriptor
parent
fdcfae3a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
41 deletions
+45
-41
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
...l/include/tensor_description/ConstantTensorDescriptor.hpp
+7
-0
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
...tensor_operation/threadwise_generic_tensor_slice_copy.hpp
+2
-2
composable_kernel/include/utility/Sequence.hpp
composable_kernel/include/utility/Sequence.hpp
+36
-39
No files found.
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
View file @
0271338e
...
...
@@ -419,6 +419,13 @@ struct ConstantTensorDescriptor
return
ConstantTensorDescriptor
<
decltype
(
Lengths
::
ReorderGivenNew2Old
(
MapNew2Old
{})),
decltype
(
Strides
::
ReorderGivenNew2Old
(
MapNew2Old
{}))
>
{};
}
template
<
class
MapOld2New
>
__host__
__device__
static
constexpr
auto
ReorderGivenOld2New
(
MapOld2New
)
{
return
ConstantTensorDescriptor
<
decltype
(
Lengths
::
ReorderGivenOld2New
(
MapOld2New
{})),
decltype
(
Strides
::
ReorderGivenOld2New
(
MapOld2New
{}))
>
{};
}
};
template
<
class
Lengths
>
...
...
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
View file @
0271338e
...
...
@@ -74,8 +74,8 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
constexpr
auto
data_multi_id_in_access_order
=
access_multi_id
.
Modify
(
Number
<
nDim
-
1
>
{},
Number
<
itmp
>
{});
constexpr
auto
data_multi_id
=
reorder_array_given_old2new
(
sequence2array
(
data_multi_id_in_access_order
),
DimAccessOrder
{});
constexpr
auto
data_multi_id
=
data_multi_id_in_access_order
.
ReorderGivenOld2New
(
DimAccessOrder
{});
const
index_t
src_index
=
SrcDesc
::
GetOffsetFromMultiIndex
(
src_multi_id_begin
+
data_multi_id
);
...
...
composable_kernel/include/utility/Sequence.hpp
View file @
0271338e
...
...
@@ -6,12 +6,27 @@
namespace
ck
{
template
<
index_t
...>
struct
Sequence
;
template
<
class
Seq
,
index_t
I
>
struct
sequence_split
;
template
<
class
>
struct
is_valid_
sequence_
map
;
struct
sequence_
reverse
;
template
<
class
>
struct
sequence_map_inverse
;
template
<
class
>
struct
is_valid_sequence_map
;
template
<
index_t
I
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
sequence_pop_front
(
Sequence
<
I
,
Is
...
>
);
template
<
class
Seq
>
__host__
__device__
constexpr
auto
sequence_pop_back
(
Seq
);
template
<
index_t
...
Is
>
struct
Sequence
{
...
...
@@ -71,7 +86,10 @@ struct Sequence
return
ReorderGivenNew2Old
(
typename
sequence_map_inverse
<
MapOld2New
>::
type
{});
}
__host__
__device__
static
constexpr
auto
Reverse
();
__host__
__device__
static
constexpr
auto
Reverse
()
{
return
typename
sequence_reverse
<
Type
>::
type
{};
}
__host__
__device__
static
constexpr
auto
Front
()
{
...
...
@@ -85,9 +103,9 @@ struct Sequence
return
Get
(
Number
<
mSize
-
1
>
{});
}
__host__
__device__
static
constexpr
auto
PopFront
()
;
__host__
__device__
static
constexpr
auto
PopFront
()
{
return
sequence_pop_front
(
Type
{});
}
__host__
__device__
static
constexpr
auto
PopBack
()
;
__host__
__device__
static
constexpr
auto
PopBack
()
{
return
sequence_pop_back
(
Type
{});
}
template
<
index_t
...
Xs
>
__host__
__device__
static
constexpr
auto
PushFront
(
Sequence
<
Xs
...
>
)
...
...
@@ -126,7 +144,16 @@ struct Sequence
}
template
<
index_t
I
,
index_t
X
>
__host__
__device__
static
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
);
__host__
__device__
static
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
)
{
static_assert
(
I
<
GetSize
(),
"wrong!"
);
using
seq_split
=
sequence_split
<
Type
,
I
>
;
constexpr
auto
seq_left
=
typename
seq_split
::
SeqType0
{};
constexpr
auto
seq_right
=
typename
seq_split
::
SeqType1
{}.
PopFront
();
return
seq_left
.
PushBack
(
Number
<
X
>
{}).
PushBack
(
seq_right
);
}
template
<
class
F
>
__host__
__device__
static
constexpr
auto
Transform
(
F
f
)
...
...
@@ -283,7 +310,8 @@ template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
struct
sequence_map_inverse_impl
{
private:
static
constexpr
auto
new_y2x
=
WorkingY2X
::
Modify
(
X2Y
{}[
XBegin
],
XBegin
);
static
constexpr
auto
new_y2x
=
WorkingY2X
::
Modify
(
X2Y
::
Get
(
Number
<
XBegin
>
{}),
Number
<
XBegin
>
{});
public:
using
type
=
...
...
@@ -417,8 +445,8 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
template
<
class
Seq
>
__host__
__device__
constexpr
auto
sequence_pop_back
(
Seq
)
{
static_assert
(
Seq
{}.
GetSize
()
>
0
,
"wrong! cannot pop an empty Sequence!"
);
return
sequence_pop_front
(
Seq
{}.
Reverse
()).
Reverse
();
static_assert
(
Seq
::
GetSize
()
>
0
,
"wrong! cannot pop an empty Sequence!"
);
return
sequence_pop_front
(
Seq
::
Reverse
()).
Reverse
();
}
template
<
class
F
,
index_t
...
Xs
>
...
...
@@ -458,37 +486,6 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
return
reverse_inclusive_scan_sequence
(
Seq
{}.
Reverse
(),
Reduce
{},
Number
<
Init
>
{}).
Reverse
();
}
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
PopFront
()
{
return
sequence_pop_front
(
Type
{});
}
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
PopBack
()
{
return
sequence_pop_back
(
Type
{});
}
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
Reverse
()
{
return
typename
sequence_reverse
<
Sequence
<
Is
...
>>::
type
{};
}
template
<
index_t
...
Is
>
template
<
index_t
I
,
index_t
X
>
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
Modify
(
Number
<
I
>
,
Number
<
X
>
)
{
static_assert
(
I
<
GetSize
(),
"wrong!"
);
using
seq_split
=
sequence_split
<
Type
,
I
>
;
constexpr
auto
seq_left
=
typename
seq_split
::
SeqType0
{};
constexpr
auto
seq_right
=
typename
seq_split
::
SeqType1
{}.
PopFront
();
return
seq_left
.
PushBack
(
Number
<
X
>
{}).
PushBack
(
seq_right
);
}
template
<
index_t
...
Xs
>
__host__
__device__
void
print_Sequence
(
const
char
*
s
,
Sequence
<
Xs
...
>
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment