Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2776c177
"driver/include/host_tensor.hpp" did not exist on "2603bb0fe3882ad75bd237db89a953e067a9c05b"
Commit
2776c177
authored
Aug 29, 2023
by
Rostyslav Geyyer
Browse files
Add bf8, use BitInt types
parent
9967360c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
94 additions
and
179 deletions
+94
-179
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+94
-179
No files found.
include/ck/utility/data_type.hpp
View file @
2776c177
...
...
@@ -12,25 +12,8 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
#endif
struct
f8_t
{
uint8_t
data
;
using
type
=
f8_t
;
using
data_type
=
uint8_t
;
__host__
__device__
f8_t
()
=
default
;
__host__
__device__
f8_t
(
uint8_t
init
);
};
struct
bf8_t
{
uint8_t
data
;
using
type
=
bf8_t
;
using
data_type
=
uint8_t
;
__host__
__device__
bf8_t
()
=
default
;
};
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
template
<
typename
T
>
inline
__host__
__device__
constexpr
auto
is_native
()
...
...
@@ -44,10 +27,6 @@ inline __host__ __device__ constexpr auto is_native()
template
<
typename
T
,
index_t
N
>
struct
vector_type
;
// // non_native_vector_type
// template <typename T, index_t N>
// struct non_native_vector_type;
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
...
...
@@ -180,14 +159,13 @@ struct scalar_type<f8_t>
static
constexpr
index_t
vector_size
=
1
;
};
// utility function for non native vector type
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
template
<
>
struct
scalar_type
<
bf8_t
>
{
// Precondition: x > 1.
return
x
>
1u
?
(
1u
<<
(
32u
-
__builtin_clz
(
x
-
1u
)))
:
x
;
}
using
type
=
bf8_t
;
static
constexpr
index_t
vector_size
=
1
;
}
;
//
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
{
...
...
@@ -224,11 +202,8 @@ struct vector_type<T, 1>
template
<
typename
T
>
struct
vector_type
<
T
,
2
>
{
using
d1_t
=
T
;
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
type
=
d2_t
;
...
...
@@ -277,14 +252,9 @@ struct vector_type<T, 2>
template
<
typename
T
>
struct
vector_type
<
T
,
4
>
{
using
d1_t
=
T
;
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
using
type
=
d4_t
;
...
...
@@ -344,17 +314,10 @@ struct vector_type<T, 4>
template
<
typename
T
>
struct
vector_type
<
T
,
8
>
{
using
d1_t
=
T
;
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
using
type
=
d8_t
;
...
...
@@ -425,20 +388,11 @@ struct vector_type<T, 8>
template
<
typename
T
>
struct
vector_type
<
T
,
16
>
{
using
d1_t
=
T
;
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d16_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
16
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
16
))))
>
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
using
type
=
d16_t
;
...
...
@@ -520,23 +474,12 @@ struct vector_type<T, 16>
template
<
typename
T
>
struct
vector_type
<
T
,
32
>
{
using
d1_t
=
T
;
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d16_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
16
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
16
))))
>
;
using
d32_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
32
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
32
))))
>
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
using
type
=
d32_t
;
...
...
@@ -627,26 +570,13 @@ struct vector_type<T, 32>
template
<
typename
T
>
struct
vector_type
<
T
,
64
>
{
using
d1_t
=
T
;
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d16_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
16
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
16
))))
>
;
using
d32_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
32
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
32
))))
>
;
using
d64_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
64
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
64
))))
>
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
using
type
=
d64_t
;
...
...
@@ -748,29 +678,14 @@ struct vector_type<T, 64>
template
<
typename
T
>
struct
vector_type
<
T
,
128
>
{
using
d1_t
=
T
;
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d16_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
16
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
16
))))
>
;
using
d32_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
32
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
32
))))
>
;
using
d64_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
64
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
64
))))
>
;
using
d128_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
128
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
128
))))
>
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
typedef
T
d128_t
__attribute__
((
ext_vector_type
(
128
)));
using
type
=
d128_t
;
...
...
@@ -881,32 +796,15 @@ struct vector_type<T, 128>
template
<
typename
T
>
struct
vector_type
<
T
,
256
>
{
using
d1_t
=
T
;
using
ext_vect_t
=
conditional_t
<
is_native
<
T
>
(),
T
,
uint32_t
>
;
using
d2_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
2
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
2
))))
>
;
using
d4_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
4
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
4
))))
>
;
using
d8_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
8
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
8
))))
>
;
using
d16_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
16
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
16
))))
>
;
using
d32_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
32
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
32
))))
>
;
using
d64_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
64
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
64
))))
>
;
using
d128_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
128
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
128
))))
>
;
using
d256_t
=
conditional_t
<
is_native
<
T
>
(),
ext_vect_t
__attribute__
((
ext_vector_type
(
256
))),
ext_vect_t
__attribute__
((
ext_vector_type
(
next_pow2
(
256
))))
>
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
typedef
T
d128_t
__attribute__
((
ext_vector_type
(
128
)));
typedef
T
d256_t
__attribute__
((
ext_vector_type
(
256
)));
using
type
=
d256_t
;
...
...
@@ -1077,6 +975,14 @@ using f8x16_t = typename vector_type<f8_t, 16>::type;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
// bf8 vector aliases: each bf8xN_t names the `type` member of vector_type<bf8_t, N>.
using bf8x2_t  = typename vector_type<bf8_t, 2>::type;
using bf8x4_t  = typename vector_type<bf8_t, 4>::type;
using bf8x8_t  = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
template
<
typename
T
>
struct
NumericLimits
{
...
...
@@ -1126,38 +1032,47 @@ struct NumericLimits<int4_t>
template
<
>
struct
NumericLimits
<
f8_t
>
{
// negative zero nan mode with exp bias = 8
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
static
constexpr
uint8_t
binary_max
=
0x7
7
;
// 0b0111
0
111
static
constexpr
uint8_t
binary_lowest
=
0xF
7
;
// 0b1111
0
111
static
constexpr
uint8_t
binary_max
=
0x7
F
;
// 0b0111
1
111
static
constexpr
uint8_t
binary_lowest
=
0xF
F
;
// 0b1111
1
111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
f8_t
Min
()
{
f8_t
x
;
x
.
data
=
binary_min
;
return
x
;
}
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
f8_t
(
binary_min
);
}
__host__
__device__
static
f8_t
Max
()
{
f8_t
x
;
x
.
data
=
binary_max
;
return
x
;
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
f8_t
(
binary_max
);
}
__host__
__device__
static
f8_t
Lowest
()
{
f8_t
x
;
x
.
data
=
binary_lowest
;
return
x
;
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
f8_t
(
binary_lowest
);
}
__host__
__device__
static
f8_t
QuietNaN
()
{
f8_t
x
;
x
.
data
=
binary_qnan
;
return
x
;
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
/// NumericLimits specialization for bf8_t.
///
/// Each limit is kept as a raw 8-bit pattern (uint8_t constant) and is
/// converted to bf8_t by the matching accessor below.
template <>
struct NumericLimits<bf8_t>
{
    // Active encoding: negative zero nan mode with exp bias = 16.
    static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000

    // Alternative encoding, kept disabled for reference: ieee mode with exp bias = 15.
    // static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    // static constexpr uint8_t binary_max    = 0x7B; // 0b01111011
    // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0
    // NOTE(review): the disabled qnan line above is word-for-word the one used for
    // f8_t; confirm that 0x79 and "exp=1111" are correct for the bf8 bit layout
    // before this mode is ever enabled.

    /// Returns the bit pattern binary_min converted to bf8_t.
    __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }

    /// Returns the bit pattern binary_max converted to bf8_t.
    __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }

    /// Returns the bit pattern binary_lowest converted to bf8_t.
    __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }

    /// Returns the bit pattern binary_qnan converted to bf8_t.
    __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment