Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3d51e246
Commit
3d51e246
authored
Aug 04, 2023
by
Rostyslav Geyyer
Browse files
Update vector_type implementation
parent
a92772bf
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
143 additions
and
44 deletions
+143
-44
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+143
-44
No files found.
include/ck/utility/data_type.hpp
View file @
3d51e246
...
@@ -32,10 +32,22 @@ struct bf8_t
...
@@ -32,10 +32,22 @@ struct bf8_t
__host__
__device__
bf8_t
()
=
default
;
__host__
__device__
bf8_t
()
=
default
;
};
};
template
<
typename
T
>
inline
__host__
__device__
constexpr
auto
is_native
()
{
return
std
::
is_same
<
T
,
half_t
>::
value
||
std
::
is_same
<
T
,
float
>::
value
||
std
::
is_same
<
T
,
double
>::
value
||
std
::
is_same
<
T
,
bhalf_t
>::
value
||
std
::
is_same
<
T
,
int32_t
>::
value
||
std
::
is_same
<
T
,
int8_t
>::
value
;
}
// vector_type
// vector_type
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
vector_type
;
struct
vector_type
;
// // non_native_vector_type
// template <typename T, index_t N>
// struct non_native_vector_type;
// Caution: DO NOT REMOVE
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
...
@@ -168,6 +180,13 @@ struct scalar_type<f8_t>
...
@@ -168,6 +180,13 @@ struct scalar_type<f8_t>
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
// utility function for non native vector type
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
// Precondition: x > 1.
return
x
>
1u
?
(
1u
<<
(
32u
-
__builtin_clz
(
x
-
1u
)))
:
x
;
}
//
//
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
struct
vector_type
<
T
,
1
>
...
@@ -206,7 +225,10 @@ template <typename T>
...
@@ -206,7 +225,10 @@ template <typename T>
struct
vector_type
<
T
,
2
>
struct
vector_type
<
T
,
2
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
type
=
d2_t
;
using
type
=
d2_t
;
...
@@ -256,8 +278,13 @@ template <typename T>
...
@@ -256,8 +278,13 @@ template <typename T>
struct
vector_type
<
T
,
4
>
struct
vector_type
<
T
,
4
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
type
=
d4_t
;
using
type
=
d4_t
;
...
@@ -318,9 +345,16 @@ template <typename T>
...
@@ -318,9 +345,16 @@ template <typename T>
struct
vector_type
<
T
,
8
>
struct
vector_type
<
T
,
8
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
type
=
d8_t
;
using
type
=
d8_t
;
...
@@ -392,10 +426,19 @@ template <typename T>
...
@@ -392,10 +426,19 @@ template <typename T>
struct
vector_type
<
T
,
16
>
struct
vector_type
<
T
,
16
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d16_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
16
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
16
))))
>
;
using
type
=
d16_t
;
using
type
=
d16_t
;
...
@@ -478,11 +521,22 @@ template <typename T>
...
@@ -478,11 +521,22 @@ template <typename T>
struct
vector_type
<
T
,
32
>
struct
vector_type
<
T
,
32
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d16_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
16
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
16
))))
>
;
using
d32_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
32
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
32
))))
>
;
using
type
=
d32_t
;
using
type
=
d32_t
;
...
@@ -574,12 +628,25 @@ template <typename T>
...
@@ -574,12 +628,25 @@ template <typename T>
struct
vector_type
<
T
,
64
>
struct
vector_type
<
T
,
64
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d16_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
16
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
16
))))
>
;
using
d32_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
32
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
32
))))
>
;
using
d64_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
64
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
64
))))
>
;
using
type
=
d64_t
;
using
type
=
d64_t
;
...
@@ -682,13 +749,28 @@ template <typename T>
...
@@ -682,13 +749,28 @@ template <typename T>
struct
vector_type
<
T
,
128
>
struct
vector_type
<
T
,
128
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
typedef
T
d128_t
__attribute__
((
ext_vector_type
(
128
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d16_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
16
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
16
))))
>
;
using
d32_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
32
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
32
))))
>
;
using
d64_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
64
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
64
))))
>
;
using
d128_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
128
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
128
))))
>
;
using
type
=
d128_t
;
using
type
=
d128_t
;
...
@@ -800,14 +882,31 @@ template <typename T>
...
@@ -800,14 +882,31 @@ template <typename T>
struct
vector_type
<
T
,
256
>
struct
vector_type
<
T
,
256
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
typedef
T
d128_t
__attribute__
((
ext_vector_type
(
128
)));
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
typedef
T
d256_t
__attribute__
((
ext_vector_type
(
256
)));
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d16_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
16
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
16
))))
>
;
using
d32_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
32
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
32
))))
>
;
using
d64_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
64
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
64
))))
>
;
using
d128_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
128
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
128
))))
>
;
using
d256_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
256
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
256
))))
>
;
using
type
=
d256_t
;
using
type
=
d256_t
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment