Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ffa2c520
Commit
ffa2c520
authored
Sep 29, 2020
by
Chao Liu
Browse files
refactoring tuple
parent
6cd94d98
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
42 additions
and
17 deletions
+42
-17
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp
.../include/tensor_description/dynamic_tensor_descriptor.hpp
+1
-1
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_v2.hpp
...clude/tensor_description/dynamic_tensor_descriptor_v2.hpp
+3
-3
composable_kernel/include/tensor_description/multi_index.hpp
composable_kernel/include/tensor_description/multi_index.hpp
+2
-2
composable_kernel/include/utility/functional4.hpp
composable_kernel/include/utility/functional4.hpp
+15
-9
composable_kernel/include/utility/tuple.hpp
composable_kernel/include/utility/tuple.hpp
+18
-0
composable_kernel/include/utility/tuple_helper.hpp
composable_kernel/include/utility/tuple_helper.hpp
+3
-2
No files found.
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp
View file @
ffa2c520
...
@@ -21,7 +21,7 @@ struct DynamicNativeTensorDescriptor
...
@@ -21,7 +21,7 @@ struct DynamicNativeTensorDescriptor
}
}
__host__
__device__
explicit
constexpr
DynamicNativeTensorDescriptor
()
__host__
__device__
explicit
constexpr
DynamicNativeTensorDescriptor
()
:
lengths_
{
make_zero_
array
<
index
_t
,
NDim
>
()},
strides_
{
make_zero_
array
<
index
_t
,
NDim
>
()}
:
lengths_
{
make_zero_
multi_
index
<
NDim
>
()},
strides_
{
make_zero_
multi_
index
<
NDim
>
()}
{
{
}
}
...
...
composable_kernel/include/tensor_description/dynamic_tensor_descriptor_v2.hpp
View file @
ffa2c520
...
@@ -408,13 +408,13 @@ transform_dynamic_tensor_descriptor_v2(const OldTensorDescriptor& old_tensor_des
...
@@ -408,13 +408,13 @@ transform_dynamic_tensor_descriptor_v2(const OldTensorDescriptor& old_tensor_des
unordered_new_visible_dim_hidden_ids
.
ReorderGivenOld2New
(
new_visible_dim_unordered2ordered
);
unordered_new_visible_dim_hidden_ids
.
ReorderGivenOld2New
(
new_visible_dim_unordered2ordered
);
// put everything together
// put everything together
const
auto
all_transforms
=
merge_
tuple
s
(
old_tensor_desc
.
GetTransforms
(),
new_transforms
);
const
auto
all_transforms
=
tuple
_cat
(
old_tensor_desc
.
GetTransforms
(),
new_transforms
);
constexpr
auto
all_low_dim_hidden_idss
=
constexpr
auto
all_low_dim_hidden_idss
=
merge_
tuple
s
(
OldTensorDescriptor
::
GetLowerDimensionIdss
(),
low_dim_hidden_idss
);
tuple
_cat
(
OldTensorDescriptor
::
GetLowerDimensionIdss
(),
low_dim_hidden_idss
);
constexpr
auto
all_up_dim_hidden_idss
=
constexpr
auto
all_up_dim_hidden_idss
=
merge_
tuple
s
(
OldTensorDescriptor
::
GetUpperDimensionIdss
(),
up_dim_hidden_idss
);
tuple
_cat
(
OldTensorDescriptor
::
GetUpperDimensionIdss
(),
up_dim_hidden_idss
);
return
DynamicTensorDescriptor_v2
<
decltype
(
all_transforms
),
return
DynamicTensorDescriptor_v2
<
decltype
(
all_transforms
),
decltype
(
all_low_dim_hidden_idss
),
decltype
(
all_low_dim_hidden_idss
),
...
...
composable_kernel/include/tensor_description/multi_index.hpp
View file @
ffa2c520
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
namespace
ck
{
namespace
ck
{
#if 1
#if 1
// dyanmically indexed array
template
<
index_t
N
>
template
<
index_t
N
>
using
MultiIndex
=
Array
<
index_t
,
N
>
;
using
MultiIndex
=
Array
<
index_t
,
N
>
;
...
@@ -22,7 +22,7 @@ __host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
...
@@ -22,7 +22,7 @@ __host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
return
make_array
<
const
index_t
>
(
std
::
forward
<
const
Xs
>
(
xs
)...);
return
make_array
<
const
index_t
>
(
std
::
forward
<
const
Xs
>
(
xs
)...);
}
}
#endif
#endif
#else
#else
// statically index array
template
<
index_t
N
>
template
<
index_t
N
>
using
MultiIndex
=
StaticallyIndexedArray
<
index_t
,
N
>
;
using
MultiIndex
=
StaticallyIndexedArray
<
index_t
,
N
>
;
...
...
composable_kernel/include/utility/functional4.hpp
View file @
ffa2c520
...
@@ -16,9 +16,9 @@ template <index_t... Is>
...
@@ -16,9 +16,9 @@ template <index_t... Is>
struct
unpack_impl
<
Sequence
<
Is
...
>>
struct
unpack_impl
<
Sequence
<
Is
...
>>
{
{
template
<
typename
F
,
typename
X
>
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
operator
()(
F
f
,
const
X
&
x
)
const
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&
&
x
)
const
{
{
return
f
(
x
.
At
(
Number
<
Is
>
{})...);
return
std
::
forward
<
F
>
(
f
)(
std
::
forward
<
X
>
(
x
)
.
At
(
Number
<
Is
>
{})...);
}
}
};
};
...
@@ -30,26 +30,32 @@ template <index_t... Is, index_t... Js>
...
@@ -30,26 +30,32 @@ template <index_t... Is, index_t... Js>
struct
unpack2_impl
<
Sequence
<
Is
...
>
,
Sequence
<
Js
...
>>
struct
unpack2_impl
<
Sequence
<
Is
...
>
,
Sequence
<
Js
...
>>
{
{
template
<
typename
F
,
typename
X
,
typename
Y
>
template
<
typename
F
,
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
operator
()(
F
f
,
const
X
&
x
,
const
Y
&
y
)
const
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&
&
x
,
Y
&
&
y
)
const
{
{
return
f
(
x
.
At
(
Number
<
Is
>
{})...,
y
.
At
(
Number
<
Js
>
{})...);
return
std
::
forward
<
F
>
(
f
)(
std
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...,
std
::
forward
<
Y
>
(
y
).
At
(
Number
<
Js
>
{})...);
}
}
};
};
}
// namespace detail
}
// namespace detail
template
<
typename
F
,
typename
X
>
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
unpack
(
F
f
,
const
X
&
x
)
__host__
__device__
constexpr
auto
unpack
(
F
&&
f
,
X
&
&
x
)
{
{
return
detail
::
unpack_impl
<
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
>
{}(
f
,
x
);
using
X_
=
remove_reference_t
<
X
>
;
return
detail
::
unpack_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
>
{}(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
X
>
(
x
));
}
}
// TODO: properly implement unpack that takes any number of containers
// TODO: properly implement unpack that takes any number of containers
template
<
typename
F
,
typename
X
,
typename
Y
>
template
<
typename
F
,
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
unpack
(
F
f
,
const
X
&
x
,
const
Y
&
y
)
__host__
__device__
constexpr
auto
unpack
(
F
&&
f
,
X
&
&
x
,
Y
&
&
y
)
{
{
return
detail
::
unpack2_impl
<
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
,
using
X_
=
remove_reference_t
<
X
>
;
typename
arithmetic_sequence_gen
<
0
,
Y
::
Size
(),
1
>::
type
>
{}(
f
,
x
,
y
);
using
Y_
=
remove_reference_t
<
Y
>
;
return
detail
::
unpack2_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
,
typename
arithmetic_sequence_gen
<
0
,
Y_
::
Size
(),
1
>::
type
>
{}(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Y
>
(
y
));
}
}
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/utility/tuple.hpp
View file @
ffa2c520
...
@@ -24,6 +24,10 @@ struct TupleElement
...
@@ -24,6 +24,10 @@ struct TupleElement
{
{
}
}
__host__
__device__
explicit
constexpr
TupleElement
(
const
TupleElement
&
)
=
default
;
__host__
__device__
explicit
constexpr
TupleElement
(
TupleElement
&&
)
=
default
;
Data
mData
;
Data
mData
;
};
};
...
@@ -39,11 +43,14 @@ __host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x
...
@@ -39,11 +43,14 @@ __host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x
return
x
.
mData
;
return
x
.
mData
;
}
}
#if 0
// TODO: not sure the use of reference is correct
template <typename Key, typename Data>
template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
{
{
return static_cast<Data&&>(x.mData);
return static_cast<Data&&>(x.mData);
}
}
#endif
template
<
typename
Indices
,
typename
...
Xs
>
template
<
typename
Indices
,
typename
...
Xs
>
struct
TupleImpl
;
struct
TupleImpl
;
...
@@ -53,14 +60,21 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
...
@@ -53,14 +60,21 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
{
{
__host__
__device__
explicit
constexpr
TupleImpl
()
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
()...
__host__
__device__
explicit
constexpr
TupleImpl
()
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
()...
{
{
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
),
"wrong! inconsistent size"
);
}
}
template
<
typename
...
Ys
>
template
<
typename
...
Ys
>
__host__
__device__
explicit
constexpr
TupleImpl
(
Ys
&&
...
ys
)
__host__
__device__
explicit
constexpr
TupleImpl
(
Ys
&&
...
ys
)
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Ys
>
(
ys
))...
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Ys
>
(
ys
))...
{
{
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Is
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
}
}
__host__
__device__
explicit
constexpr
TupleImpl
(
const
TupleImpl
&
)
=
default
;
__host__
__device__
explicit
constexpr
TupleImpl
(
TupleImpl
&&
)
=
default
;
__host__
__device__
static
constexpr
index_t
Size
()
{
return
sizeof
...(
Xs
);
}
__host__
__device__
static
constexpr
index_t
Size
()
{
return
sizeof
...(
Xs
);
}
template
<
index_t
I
>
template
<
index_t
I
>
...
@@ -89,6 +103,10 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
...
@@ -89,6 +103,10 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
{
{
}
}
__host__
__device__
explicit
constexpr
Tuple
(
const
Tuple
&
)
=
default
;
__host__
__device__
explicit
constexpr
Tuple
(
Tuple
&&
)
=
default
;
__host__
__device__
static
constexpr
index_t
Size
()
{
return
sizeof
...(
Xs
);
}
__host__
__device__
static
constexpr
index_t
Size
()
{
return
sizeof
...(
Xs
);
}
template
<
index_t
I
>
template
<
index_t
I
>
...
...
composable_kernel/include/utility/tuple_helper.hpp
View file @
ffa2c520
...
@@ -13,9 +13,10 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
...
@@ -13,9 +13,10 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
}
}
template
<
typename
...
Tuples
>
template
<
typename
...
Tuples
>
__host__
__device__
constexpr
auto
merge_
tuple
s
(
Tuples
...
tuples
)
__host__
__device__
constexpr
auto
tuple
_cat
(
Tuples
&&
...
tuples
)
{
{
return
unpack
([
&
tuples
...](
auto
...
xs
)
{
return
make_tuple
(
xs
...);
},
tuples
...);
return
unpack
([
&
](
auto
&&
...
xs
)
{
return
make_tuple
(
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
},
std
::
forward
<
Tuples
>
(
tuples
)...);
}
}
namespace
detail
{
namespace
detail
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment