Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e7f633c5
Commit
e7f633c5
authored
Sep 30, 2020
by
Chao Liu
Browse files
refactoring array, tuple
parent
ffa2c520
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
408 additions
and
24 deletions
+408
-24
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform_v2.hpp
...l/include/kernel_algorithm/dummy_dynamic_transform_v2.hpp
+1
-4
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
...lude/tensor_description/dynamic_multi_index_transform.hpp
+94
-0
composable_kernel/include/tensor_description/multi_index.hpp
composable_kernel/include/tensor_description/multi_index.hpp
+3
-2
composable_kernel/include/utility/config.amd.hpp.in
composable_kernel/include/utility/config.amd.hpp.in
+3
-0
composable_kernel/include/utility/statically_indexed_array.hpp
...sable_kernel/include/utility/statically_indexed_array.hpp
+239
-2
composable_kernel/include/utility/tuple.hpp
composable_kernel/include/utility/tuple.hpp
+67
-15
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+1
-1
No files found.
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform_v2.hpp
View file @
e7f633c5
...
...
@@ -291,10 +291,7 @@ struct DummyDynamicTransform_v2_2
MultiIndex
<
2
>
idx
;
// initialize idx
for
(
index_t
i
=
0
;
i
<
2
;
++
i
)
{
idx
(
i
)
=
p_wei_global
[
get_thread_local_1d_id
()
+
i
];
}
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
idx
(
i
)
=
p_wei_global
[
get_thread_local_1d_id
()
+
i
];
});
#if 0
const index_t niter = p_wei_global[10];
...
...
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
View file @
e7f633c5
...
...
@@ -13,6 +13,22 @@ struct DynamicPassThrough
const
UpperIndex
up_lengths_
;
#if 0
__host__ __device__ explicit constexpr DynamicPassThrough(const DynamicPassThrough&) = default;
__host__ __device__ explicit constexpr DynamicPassThrough(DynamicPassThrough&&) = default;
#else
// Explicit copy constructor, written out by hand in place of the `= default`
// form that is kept under `#if 0` just above.
// Copies the upper lengths from `other`.
__host__ __device__ constexpr explicit DynamicPassThrough(const DynamicPassThrough& other)
    : up_lengths_{other.up_lengths_}
{
}
// Explicit move constructor. up_lengths_ is a const member and the initializer
// reads it through the named parameter `other`, so the value is copied, not moved.
__host__ __device__ constexpr explicit DynamicPassThrough(DynamicPassThrough&& other)
    : up_lengths_{other.up_lengths_}
{
}
#endif
__host__
__device__
explicit
constexpr
DynamicPassThrough
(
const
index_t
&
low_length
)
:
up_lengths_
{
make_multi_index
(
low_length
)}
{
...
...
@@ -72,6 +88,22 @@ struct DynamicLeftPad
const
UpperIndex
up_lengths_
;
const
index_t
left_pad_
;
#if 0
__host__ __device__ explicit constexpr DynamicLeftPad(const DynamicLeftPad&) = default;
__host__ __device__ explicit constexpr DynamicLeftPad(DynamicLeftPad&&) = default;
#else
// Explicit copy constructor, written out by hand in place of the `= default`
// form that is kept under `#if 0` just above.
// Copies both the upper lengths and the left padding amount from `other`.
__host__ __device__ constexpr explicit DynamicLeftPad(const DynamicLeftPad& other)
    : up_lengths_{other.up_lengths_}, left_pad_{other.left_pad_}
{
}
// Explicit move constructor. Both members are const and are read through the
// named parameter `other`, so their values are copied, not moved.
__host__ __device__ constexpr explicit DynamicLeftPad(DynamicLeftPad&& other)
    : up_lengths_{other.up_lengths_}, left_pad_{other.left_pad_}
{
}
#endif
__host__
__device__
explicit
constexpr
DynamicLeftPad
(
const
index_t
&
low_length
,
const
index_t
&
left_pad
)
:
up_lengths_
{
make_multi_index
(
low_length
+
left_pad
)},
left_pad_
{
left_pad
}
...
...
@@ -135,6 +167,26 @@ struct DynamicRightPad
const
index_t
low_length_
;
const
index_t
right_pad_
;
#if 0
__host__ __device__ explicit constexpr DynamicRightPad(const DynamicRightPad&) = default;
__host__ __device__ explicit constexpr DynamicRightPad(DynamicRightPad&&) = default;
#else
// Explicit copy constructor, written out by hand in place of the `= default`
// form that is kept under `#if 0` just above.
// Copies up_lengths_, low_length_ and right_pad_ from `other`.
__host__ __device__ constexpr explicit DynamicRightPad(const DynamicRightPad& other)
    : up_lengths_{other.up_lengths_},
      low_length_{other.low_length_},
      right_pad_{other.right_pad_}
{
}
// Explicit move constructor. Every member is initialized from the named
// parameter `other`, so each value is copied rather than moved.
__host__ __device__ constexpr explicit DynamicRightPad(DynamicRightPad&& other)
    : up_lengths_{other.up_lengths_},
      low_length_{other.low_length_},
      right_pad_{other.right_pad_}
{
}
#endif
__host__
__device__
explicit
constexpr
DynamicRightPad
(
const
index_t
&
low_length
,
const
index_t
&
right_pad
)
:
up_lengths_
{
make_multi_index
(
low_length
+
right_pad
)},
...
...
@@ -203,6 +255,21 @@ struct DynamicEmbed
const
UpperIndex
up_lengths_
;
const
UpperIndex
coefficients_
;
#if 0
__host__ __device__ explicit constexpr DynamicEmbed(const DynamicEmbed&) = default;
__host__ __device__ explicit constexpr DynamicEmbed(DynamicEmbed&&) = default;
#else
// Explicit copy constructor, written out by hand in place of the `= default`
// form that is kept under `#if 0` just above.
// Copies the upper lengths and the coefficients from `other`.
__host__ __device__ constexpr explicit DynamicEmbed(const DynamicEmbed& other)
    : up_lengths_{other.up_lengths_}, coefficients_{other.coefficients_}
{
}
// Explicit move constructor. Both members are const and are read through the
// named parameter `other`, so their values are copied, not moved.
__host__ __device__ constexpr explicit DynamicEmbed(DynamicEmbed&& other)
    : up_lengths_{other.up_lengths_}, coefficients_{other.coefficients_}
{
}
#endif
__host__
__device__
explicit
constexpr
DynamicEmbed
(
const
UpperIndex
&
up_lengths
,
const
UpperIndex
&
coefficients
)
:
up_lengths_
{
up_lengths
},
coefficients_
{
coefficients
}
...
...
@@ -210,6 +277,13 @@ struct DynamicEmbed
static_assert
(
UpperIndex
::
Size
()
==
NDimUp
,
"wrong! # of dimensions not consistent"
);
}
// Generic two-argument constructor: accepts any UpperLengths / Coefficients
// types from which up_lengths_ and coefficients_ can be brace-initialized.
// Unlike the UpperIndex overload, no dimension check is performed in the body.
template <typename UpperLengths, typename Coefficients>
__host__ __device__ constexpr explicit DynamicEmbed(const UpperLengths& up_lengths,
                                                    const Coefficients& coefficients)
    : up_lengths_{up_lengths}, coefficients_{coefficients}
{
}
__host__
__device__
explicit
constexpr
DynamicEmbed
()
:
up_lengths_
{
make_zero_multi_index
<
NDimUp
>
()},
coefficients_
{
make_zero_multi_index
<
NDimUp
>
()}
...
...
@@ -277,6 +351,26 @@ struct DynamicMerge
const
LowerIndex
low_lengths_scan_
;
const
UpperIndex
up_lengths_
;
#if 0
__host__ __device__ explicit constexpr DynamicMerge(const DynamicMerge&) = default;
__host__ __device__ explicit constexpr DynamicMerge(DynamicMerge&&) = default;
#else
// Explicit copy constructor, written out by hand in place of the `= default`
// form that is kept under `#if 0` just above.
// Copies low_lengths_, low_lengths_scan_ and up_lengths_ from `other`, so the
// scan is taken over as-is rather than being recomputed.
__host__ __device__ constexpr explicit DynamicMerge(const DynamicMerge& other)
    : low_lengths_{other.low_lengths_},
      low_lengths_scan_{other.low_lengths_scan_},
      up_lengths_{other.up_lengths_}
{
}
// Explicit move constructor. Every member is initialized from the named
// parameter `other`, so each value is copied rather than moved.
__host__ __device__ constexpr explicit DynamicMerge(DynamicMerge&& other)
    : low_lengths_{other.low_lengths_},
      low_lengths_scan_{other.low_lengths_scan_},
      up_lengths_{other.up_lengths_}
{
}
#endif
__host__
__device__
explicit
constexpr
DynamicMerge
(
const
LowerIndex
&
low_lengths
)
:
low_lengths_
{
low_lengths
},
low_lengths_scan_
{
reverse_exclusive_scan_on_array
(
...
...
composable_kernel/include/tensor_description/multi_index.hpp
View file @
e7f633c5
...
...
@@ -5,7 +5,7 @@
namespace
ck
{
#if
1 // dynamically indexed array
#if
CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX
template
<
index_t
N
>
using
MultiIndex
=
Array
<
index_t
,
N
>
;
...
...
@@ -22,7 +22,8 @@ __host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
return
make_array
<
const
index_t
>
(
std
::
forward
<
const
Xs
>
(
xs
)...);
}
#endif
#else // statically indexed array
#else
template
<
index_t
N
>
using
MultiIndex
=
StaticallyIndexedArray
<
index_t
,
N
>
;
...
...
composable_kernel/include/utility/config.amd.hpp.in
View file @
e7f633c5
...
...
@@ -8,6 +8,9 @@
// index type: unsigned or signed
#define CK_UNSIGNED_INDEX_TYPE 0
// multi index
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 1
// device backend
#define CK_DEVICE_BACKEND_AMD 1
...
...
composable_kernel/include/utility/statically_indexed_array.hpp
View file @
e7f633c5
...
...
@@ -7,6 +7,7 @@
namespace
ck
{
#if 0
template <typename T, index_t NSize>
struct StaticallyIndexedArray
{
...
...
@@ -19,19 +20,57 @@ struct StaticallyIndexedArray<T, 0> : public Tuple<>
using base = Tuple<>;
__host__ __device__ explicit constexpr StaticallyIndexedArray() : base() {}
__host__ __device__ explicit constexpr StaticallyIndexedArray(const StaticallyIndexedArray&) =
default;
__host__
__device__ explicit constexpr StaticallyIndexedArray(StaticallyIndexedArray&&) = default;
};
template <typename T>
struct StaticallyIndexedArray<T, 1> : public Tuple<T>
{
using type = StaticallyIndexedArray;
using data_type = T;
using base = Tuple<T>;
static constexpr index_t nsize = base::Size();
template
<
typename
...
Ys
>
__host__ __device__ explicit constexpr StaticallyIndexedArray(const StaticallyIndexedArray&) =
default;
__host__
__device__ explicit constexpr StaticallyIndexedArray(StaticallyIndexedArray&&) = default;
template <typename Y>
__host__
__device__ explicit constexpr StaticallyIndexedArray(const StaticallyIndexedArray<Y, nsize>& y)
: base(static_cast<const Tuple<Y>&>(y))
{
}
template <typename Y>
__host__ __device__ explicit constexpr StaticallyIndexedArray(StaticallyIndexedArray<Y, nsize>&& y)
: base(static_cast<Tuple<Y>&&>(y))
{
}
#if 0
template <typename... Ys,
typename std::enable_if<sizeof...(Ys) == base::Size(),
bool>::type = false>
__host__ __device__ explicit constexpr StaticallyIndexedArray(Ys&&... ys)
: base(std::forward<Ys>(ys)...)
{
static_assert(sizeof...(Ys) == nsize, "wrong! inconsistent size");
}
#else
template <typename Y>
__host__ __device__ explicit constexpr StaticallyIndexedArray(Y&& y)
: base(std::forward<Y>(y))
{
}
#endif
};
template
<
typename
T
>
...
...
@@ -40,10 +79,32 @@ struct StaticallyIndexedArray<T, 2> : public Tuple<T, T>
using
data_type
=
T
;
using
base
=
Tuple
<
T
,
T
>
;
template
<
typename
...
Ys
>
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
template
<
typename
Y
>
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
<
Y
,
2
>&
y
)
:
base
(
static_cast
<
const
Tuple
<
Y
,
Y
>&>
(
y
))
{
}
template
<
typename
Y
>
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
<
Y
,
2
>&&
y
)
:
base
(
static_cast
<
Tuple
<
Y
,
Y
>&&>
(
y
))
{
}
template
<
typename
...
Ys
,
typename
std
::
enable_if
<
sizeof
...(
Ys
)
==
base
::
Size
(),
bool
>
::
type
=
false
>
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
Ys
&&
...
ys
)
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
static_assert
(
sizeof
...(
Ys
)
==
2
,
"wrong! inconsistent size"
);
}
};
...
...
@@ -58,6 +119,12 @@ struct StaticallyIndexedArray<T, 3> : public Tuple<T, T, T>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -71,6 +138,12 @@ struct StaticallyIndexedArray<T, 4> : public Tuple<T, T, T, T>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -84,6 +157,12 @@ struct StaticallyIndexedArray<T, 5> : public Tuple<T, T, T, T, T>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -97,6 +176,12 @@ struct StaticallyIndexedArray<T, 6> : public Tuple<T, T, T, T, T, T>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -110,6 +195,12 @@ struct StaticallyIndexedArray<T, 7> : public Tuple<T, T, T, T, T, T, T>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -123,6 +214,12 @@ struct StaticallyIndexedArray<T, 8> : public Tuple<T, T, T, T, T, T, T, T>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -136,6 +233,12 @@ struct StaticallyIndexedArray<T, 9> : public Tuple<T, T, T, T, T, T, T, T, T>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -149,6 +252,12 @@ struct StaticallyIndexedArray<T, 10> : public Tuple<T, T, T, T, T, T, T, T, T, T
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -162,6 +271,12 @@ struct StaticallyIndexedArray<T, 11> : public Tuple<T, T, T, T, T, T, T, T, T, T
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -175,6 +290,12 @@ struct StaticallyIndexedArray<T, 12> : public Tuple<T, T, T, T, T, T, T, T, T, T
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -188,6 +309,12 @@ struct StaticallyIndexedArray<T, 13> : public Tuple<T, T, T, T, T, T, T, T, T, T
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -201,6 +328,12 @@ struct StaticallyIndexedArray<T, 14> : public Tuple<T, T, T, T, T, T, T, T, T, T
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -214,6 +347,12 @@ struct StaticallyIndexedArray<T, 15> : public Tuple<T, T, T, T, T, T, T, T, T, T
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -227,6 +366,12 @@ struct StaticallyIndexedArray<T, 16> : public Tuple<T, T, T, T, T, T, T, T, T, T
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -241,6 +386,12 @@ struct StaticallyIndexedArray<T, 17>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -255,6 +406,12 @@ struct StaticallyIndexedArray<T, 18>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -269,6 +426,12 @@ struct StaticallyIndexedArray<T, 19>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -283,6 +446,12 @@ struct StaticallyIndexedArray<T, 20>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -297,6 +466,12 @@ struct StaticallyIndexedArray<T, 21>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
template
<
typename
T
>
...
...
@@ -311,7 +486,69 @@ struct StaticallyIndexedArray<T, 22>
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
const
StaticallyIndexedArray
&
)
=
default
;
__host__
__device__
explicit
constexpr
StaticallyIndexedArray
(
StaticallyIndexedArray
&&
)
=
default
;
};
#else
namespace detail {

// Calls generate_tuple with a generator that ignores its argument and returns
// `T{}`, together with Number<NSize>{}.
// NOTE(review): presumably this yields a tuple of NSize elements of type T;
// generate_tuple is defined elsewhere, so confirm against its definition.
template <typename T, index_t NSize>
__host__ __device__ constexpr auto generate_same_type_tuple()
{
    return generate_tuple([](auto) -> T { return T{}; }, Number<NSize>{});
}

// The type returned by generate_same_type_tuple<T, NSize>().
template <typename T, index_t NSize>
using same_type_tuple = decltype(generate_same_type_tuple<T, NSize>());

} // namespace detail
#if 0
template <typename T, index_t NSize>
struct StaticallyIndexedArray : public detail::same_type_tuple<T, NSize>
{
using type = StaticallyIndexedArray;
using data_type = T;
using base = detail::same_type_tuple<T, NSize>;
__host__ __device__ explicit constexpr StaticallyIndexedArray(const StaticallyIndexedArray&) =
default;
__host__
__device__ explicit constexpr StaticallyIndexedArray(StaticallyIndexedArray&&) = default;
template <typename Y>
__host__ __device__ explicit constexpr StaticallyIndexedArray(
const StaticallyIndexedArray<Y, NSize>& y)
: base(static_cast<const detail::same_type_tuple<Y, NSize>&>(y))
{
}
template <typename Y>
__host__
__device__ explicit constexpr StaticallyIndexedArray(StaticallyIndexedArray<Y, NSize>&& y)
: base(static_cast<detail::same_type_tuple<Y, NSize>&&>(y))
{
}
template <typename... Ys,
typename std::enable_if<sizeof...(Ys) == base::Size(), bool>::type = false>
__host__ __device__ explicit constexpr StaticallyIndexedArray(Ys&&... ys)
: base(std::forward<Ys>(ys)...)
{
static_assert(sizeof...(Ys) == NSize, "wrong! inconsistent size");
}
};
#else
// StaticallyIndexedArray<T, NSize> is a plain alias for
// detail::same_type_tuple<T, NSize>; it adds no members of its own.
template <typename T, index_t NSize>
using StaticallyIndexedArray = detail::same_type_tuple<T, NSize>;
#endif
#endif
template
<
typename
X
,
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_statically_indexed_array
(
const
X
&
x
,
const
Xs
&
...
xs
)
...
...
composable_kernel/include/utility/tuple.hpp
View file @
e7f633c5
...
...
@@ -19,14 +19,26 @@ struct TupleElement
{
__host__
__device__
explicit
constexpr
TupleElement
()
:
mData
()
{}
template
<
typename
T
>
__host__
__device__
explicit
constexpr
TupleElement
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
__host__
__device__
explicit
constexpr
TupleElement
(
const
TupleElement
&
)
=
default
;
__host__
__device__
explicit
constexpr
TupleElement
(
TupleElement
&&
)
=
default
;
template
<
typename
UData
>
__host__
__device__
explicit
constexpr
TupleElement
(
const
TupleElement
<
Key
,
UData
>&
te
)
:
mData
(
static_cast
<
const
UData
&>
(
te
.
mData
))
{
}
__host__
__device__
explicit
constexpr
TupleElement
(
const
TupleElement
&
)
=
default
;
template
<
typename
UData
>
__host__
__device__
explicit
constexpr
TupleElement
(
TupleElement
<
Key
,
UData
>&&
te
)
:
mData
(
static_cast
<
UData
&&>
(
te
.
mData
))
{
}
__host__
__device__
explicit
constexpr
TupleElement
(
TupleElement
&&
)
=
default
;
template
<
typename
T
>
__host__
__device__
explicit
constexpr
TupleElement
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
{
}
Data
mData
;
};
...
...
@@ -34,7 +46,7 @@ struct TupleElement
// Returns a const reference to the data held by the tuple element `x`.
//
// Key  - tag type identifying the element.
// Data - type of the stored value.
//
// The block previously contained two consecutive return statements
// (`return x.mData;` followed by the static_cast form); the second could never
// execute. Only one is kept. The explicit cast spells out the declared return
// type and yields the same reference as returning x.mData directly.
template <typename Key, typename Data>
__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
{
    return static_cast<const Data&>(x.mData);
}
template
<
typename
Key
,
typename
Data
>
...
...
@@ -43,14 +55,12 @@ __host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x
return
x
.
mData
;
}
#if 0
// TODO: not sure the use of reference is correct
// Rvalue overload: casts the stored data of `x` to Data&& and returns it.
// This overload is currently compiled out by the surrounding #if 0 / #endif.
template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
{
    return static_cast<Data&&>(x.mData);
}
#endif
template
<
typename
Indices
,
typename
...
Xs
>
struct
TupleImpl
;
...
...
@@ -63,7 +73,25 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
),
"wrong! inconsistent size"
);
}
template
<
typename
...
Ys
>
__host__
__device__
explicit
constexpr
TupleImpl
(
const
TupleImpl
&
)
=
default
;
__host__
__device__
explicit
constexpr
TupleImpl
(
TupleImpl
&&
)
=
default
;
// Converting copy constructor from a TupleImpl with different index/element
// packs. Each TupleElement<TupleElementKey<Is>, Xs> base is constructed from
// `y` viewed as its TupleElement<TupleElementKey<Js>, Ys> base; the four packs
// are expanded together, so they must all have the same length.
template <index_t... Js, typename... Ys>
__host__ __device__ constexpr explicit TupleImpl(const TupleImpl<Sequence<Js...>, Ys...>& y)
    : TupleElement<TupleElementKey<Is>, Xs>(
          static_cast<const TupleElement<TupleElementKey<Js>, Ys>&>(y))...
{
}
// Converting move constructor from a TupleImpl with different index/element
// packs. Each TupleElement<TupleElementKey<Is>, Xs> base is constructed from
// `y` cast to an rvalue reference to its TupleElement<TupleElementKey<Js>, Ys>
// base; the four packs are expanded together, so they must all have the same
// length.
template <index_t... Js, typename... Ys>
__host__ __device__ constexpr explicit TupleImpl(TupleImpl<Sequence<Js...>, Ys...>&& y)
    : TupleElement<TupleElementKey<Is>, Xs>(
          static_cast<TupleElement<TupleElementKey<Js>, Ys>&&>(y))...
{
}
template
<
typename
...
Ys
,
typename
std
::
enable_if
<
sizeof
...(
Ys
)
>
=
1
,
bool
>::
type
=
false
>
__host__
__device__
explicit
constexpr
TupleImpl
(
Ys
&&
...
ys
)
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Ys
>
(
ys
))...
{
...
...
@@ -71,10 +99,6 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
"wrong! inconsistent size"
);
}
__host__
__device__
explicit
constexpr
TupleImpl
(
const
TupleImpl
&
)
=
default
;
__host__
__device__
explicit
constexpr
TupleImpl
(
TupleImpl
&&
)
=
default
;
// Number of elements in the tuple, i.e. the length of the Xs type pack.
__host__ __device__ static constexpr index_t Size()
{
    return sizeof...(Xs);
}
template
<
index_t
I
>
...
...
@@ -98,14 +122,42 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
using
base
=
detail
::
TupleImpl
<
typename
arithmetic_sequence_gen
<
0
,
sizeof
...(
Xs
),
1
>::
type
,
Xs
...
>
;
__host__
__device__
explicit
constexpr
Tuple
()
:
base
()
{}
__host__
__device__
explicit
constexpr
Tuple
(
const
Tuple
&
)
=
default
;
__host__
__device__
explicit
constexpr
Tuple
(
Tuple
&&
)
=
default
;
#if 0
template <typename... Ys,
typename std::enable_if<sizeof...(Ys) == sizeof...(Xs), bool>::type = false>
#else
template
<
typename
...
Ys
>
__host__
__device__
explicit
constexpr
Tuple
(
Ys
&&
...
ys
)
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
#endif
__host__
__device__
explicit
constexpr
Tuple
(
const
Tuple
<
Ys
...
>&
y
)
:
base
(
static_cast
<
const
detail
::
TupleImpl
<
typename
arithmetic_sequence_gen
<
0
,
sizeof
...(
Ys
),
1
>::
type
,
Ys
...
>&>
(
y
))
{
}
__host__
__device__
explicit
constexpr
Tuple
(
const
Tuple
&
)
=
default
;
#if 0
template <typename... Ys,
typename std::enable_if<sizeof...(Ys) == sizeof...(Xs), bool>::type = false>
#else
template
<
typename
...
Ys
>
#endif
__host__
__device__
explicit
constexpr
Tuple
(
Tuple
<
Ys
...
>&&
y
)
:
base
(
static_cast
<
detail
::
TupleImpl
<
typename
arithmetic_sequence_gen
<
0
,
sizeof
...(
Ys
),
1
>::
type
,
Ys
...
>&&>
(
y
))
{
}
__host__
__device__
explicit
constexpr
Tuple
(
Tuple
&&
)
=
default
;
template
<
typename
...
Ys
,
typename
std
::
enable_if
<
sizeof
...(
Ys
)
>
=
1
,
bool
>::
type
=
false
>
__host__
__device__
explicit
constexpr
Tuple
(
Ys
&&
...
ys
)
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
{
}
// Number of elements in the tuple, i.e. the length of the Xs type pack.
__host__ __device__ static constexpr index_t Size()
{
    return sizeof...(Xs);
}
...
...
driver/src/conv_driver.cpp
View file @
e7f633c5
...
...
@@ -549,7 +549,7 @@ int main(int argc, char* argv[])
#endif
}
#if
0
#if
1
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment