Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
803b9db8
Commit
803b9db8
authored
Jul 12, 2023
by
Rostyslav Geyyer
Browse files
Refactor f8_t to add bf8_t
parent
1cf50031
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
124 additions
and
64 deletions
+124
-64
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+104
-50
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+10
-4
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+10
-10
No files found.
include/ck/utility/data_type.hpp
View file @
803b9db8
...
...
@@ -12,7 +12,25 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
#endif
using
f8_t
=
uint8_t
;
struct
f8_t
{
uint8_t
data
;
using
type
=
f8_t
;
using
data_type
=
uint8_t
;
__host__
__device__
f8_t
()
=
default
;
__host__
__device__
f8_t
(
uint8_t
init
);
};
struct
bf8_t
{
uint8_t
data
;
using
type
=
bf8_t
;
using
data_type
=
uint8_t
;
__host__
__device__
bf8_t
()
=
default
;
};
// vector_type
template
<
typename
T
,
index_t
N
>
...
...
@@ -187,8 +205,10 @@ struct vector_type<T, 1>
template
<
typename
T
>
struct
vector_type
<
T
,
2
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
using
type
=
d2_t
;
...
...
@@ -237,9 +257,11 @@ struct vector_type<T, 2>
template
<
typename
T
>
struct
vector_type
<
T
,
4
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d4_t
__attribute__
((
ext_vector_type
(
4
)));
using
type
=
d4_t
;
...
...
@@ -299,10 +321,12 @@ struct vector_type<T, 4>
template
<
typename
T
>
struct
vector_type
<
T
,
8
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T_adjusted
d8_t
__attribute__
((
ext_vector_type
(
8
)));
using
type
=
d8_t
;
...
...
@@ -373,11 +397,13 @@ struct vector_type<T, 8>
template
<
typename
T
>
struct
vector_type
<
T
,
16
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T_adjusted
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T_adjusted
d16_t
__attribute__
((
ext_vector_type
(
16
)));
using
type
=
d16_t
;
...
...
@@ -459,12 +485,14 @@ struct vector_type<T, 16>
template
<
typename
T
>
struct
vector_type
<
T
,
32
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T_adjusted
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T_adjusted
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T_adjusted
d32_t
__attribute__
((
ext_vector_type
(
32
)));
using
type
=
d32_t
;
...
...
@@ -555,13 +583,15 @@ struct vector_type<T, 32>
template
<
typename
T
>
struct
vector_type
<
T
,
64
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T_adjusted
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T_adjusted
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T_adjusted
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T_adjusted
d64_t
__attribute__
((
ext_vector_type
(
64
)));
using
type
=
d64_t
;
...
...
@@ -663,14 +693,16 @@ struct vector_type<T, 64>
template
<
typename
T
>
struct
vector_type
<
T
,
128
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
typedef
T
d128_t
__attribute__
((
ext_vector_type
(
128
)));
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T_adjusted
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T_adjusted
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T_adjusted
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T_adjusted
d64_t
__attribute__
((
ext_vector_type
(
64
)));
typedef
T_adjusted
d128_t
__attribute__
((
ext_vector_type
(
128
)));
using
type
=
d128_t
;
...
...
@@ -781,15 +813,17 @@ struct vector_type<T, 128>
template
<
typename
T
>
struct
vector_type
<
T
,
256
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
typedef
T
d128_t
__attribute__
((
ext_vector_type
(
128
)));
typedef
T
d256_t
__attribute__
((
ext_vector_type
(
256
)));
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T_adjusted
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T_adjusted
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T_adjusted
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T_adjusted
d64_t
__attribute__
((
ext_vector_type
(
64
)));
typedef
T_adjusted
d128_t
__attribute__
((
ext_vector_type
(
128
)));
typedef
T_adjusted
d256_t
__attribute__
((
ext_vector_type
(
256
)));
using
type
=
d256_t
;
...
...
@@ -1013,14 +1047,34 @@ struct NumericLimits<f8_t>
static
constexpr
uint8_t
binary_max
=
0x77
;
// 0b01110111
static
constexpr
uint8_t
binary_lowest
=
0xF7
;
// 0b11110111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
__host__
__device__
static
f8_t
Min
()
{
f8_t
x
;
x
.
data
=
binary_min
;
return
x
;
}
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
bit_cast
<
f8_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
bit_cast
<
f8_t
>
(
binary_max
);
}
__host__
__device__
static
f8_t
Max
()
{
f8_t
x
;
x
.
data
=
binary_max
;
return
x
;
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
bit_cast
<
f8_t
>
(
binary_lowest
);
}
__host__
__device__
static
f8_t
Lowest
()
{
f8_t
x
;
x
.
data
=
binary_lowest
;
return
x
;
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
bit_cast
<
f8_t
>
(
binary_qnan
);
}
__host__
__device__
static
f8_t
QuietNaN
()
{
f8_t
x
;
x
.
data
=
binary_qnan
;
return
x
;
}
};
}
// namespace ck
include/ck/utility/f8_utils.hpp
View file @
803b9db8
...
...
@@ -23,7 +23,7 @@ namespace ck::utils {
namespace
{
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f
8_t
run_cast_to_f8
(
T
x
,
uint32_t
rng
)
__host__
__device__
uint
8_t
run_cast_to_f8
(
T
x
,
uint32_t
rng
)
{
// check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
...
...
@@ -133,7 +133,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
run_cast_from_f8
(
f
8_t
x
)
__host__
__device__
T
run_cast_from_f8
(
uint
8_t
x
)
{
// check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
...
...
@@ -222,7 +222,7 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
}
// namespace
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f
8_t
cast_to_f8
(
T
x
,
uint32_t
rng
)
__host__
__device__
uint
8_t
cast_to_f8
(
T
x
,
uint32_t
rng
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
...
...
@@ -233,7 +233,7 @@ __host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
cast_from_f8
(
f
8_t
x
)
__host__
__device__
T
cast_from_f8
(
uint
8_t
x
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
...
...
@@ -248,3 +248,9 @@ __host__ __device__ T cast_from_f8(f8_t x)
}
}
// namespace ck::utils
// f8_t constuctor impl
inline
__host__
__device__
ck
::
f8_t
::
f8_t
(
uint8_t
init
)
{
data
=
init
;
}
include/ck/utility/type_convert.hpp
View file @
803b9db8
...
...
@@ -106,8 +106,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
f8_t
(
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
)
)
;
}
// convert fp8 to fp32
...
...
@@ -115,7 +115,7 @@ template <>
inline
__host__
__device__
float
type_convert
<
float
,
f8_t
>
(
f8_t
x
)
{
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
float
,
negative_zero_nan
>
(
x
.
data
);
}
// convert fp16 to fp8
...
...
@@ -126,8 +126,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
f8_t
(
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
)
)
;
}
// convert fp8 to fp16
...
...
@@ -135,7 +135,7 @@ template <>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_t
>
(
f8_t
x
)
{
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
half_t
,
negative_zero_nan
>
(
x
.
data
);
}
// Declare a template function for bf16 conversion using RTN
...
...
@@ -209,8 +209,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
f8_t
(
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
)
)
;
}
// convert fp16 to fp8 with stochastic rounding
...
...
@@ -223,8 +223,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
f8_t
(
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
)
)
;
}
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment