Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c76b765a
Commit
c76b765a
authored
Oct 11, 2024
by
Andriy Roshchenko
Browse files
Implement FP8OCP test for stochastic rounding mode.
parent
d40d1ff1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
105 additions
and
25 deletions
+105
-25
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+53
-24
test/data_type/test_fp8_ocp.cpp
test/data_type/test_fp8_ocp.cpp
+52
-1
No files found.
include/ck/utility/data_type.hpp
View file @
c76b765a
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#pragma once
#pragma once
#include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/random_gen.hpp"
#ifdef CK_USE_FNUZ_FP8
#ifdef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 1
#define CK_USE_FNUZ_FP8 1
...
@@ -240,7 +241,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -240,7 +241,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
}
}
// First need to check if it is normal or denorm as there is a difference of
// First need to check if it is normal or denorm as there is a difference of
// implict 1 Then need to adjust the exponent to align with the F8 exponent,
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// need to check whether there is carry and adjust exponent and mantissa again
...
@@ -275,7 +276,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -275,7 +276,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
{
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
...
@@ -303,7 +304,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -303,7 +304,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
else
if
(
exponent_diff
==
-
1
)
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1ull
<<
mfmt
);
bool
implicit_one
=
mantissa
&
(
1ull
<<
mfmt
);
// if there is no implict 1, it means the f8 is denormal and need to adjust
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
// to denorm exponent
f8_exponent
=
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
...
@@ -530,9 +531,8 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret)
...
@@ -530,9 +531,8 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret)
// The conversion function is from rocblas
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
template
<
bool
stochastic_rounding
=
false
>
template
<
ck_fp8_interpretation_t
interpret
,
bool
saturate
,
bool
stochastic_rounding
=
false
>
static
__device__
fp8_storage_t
static
__device__
fp8_storage_t
cast_to_f8_from_f32
(
float
v
,
unsigned
int
rng
=
0
)
cast_to_f8_from_f32
(
float
v
,
bool
saturate
,
ck_fp8_interpretation_t
interpret
,
unsigned
int
rng
=
0
)
{
{
fp8_storage_t
i8data
;
fp8_storage_t
i8data
;
union
union
...
@@ -545,9 +545,9 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u
...
@@ -545,9 +545,9 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u
unsigned
int
ival
=
0
;
unsigned
int
ival
=
0
;
val
.
fval
=
v
;
val
.
fval
=
v
;
if
(
saturate
)
if
constexpr
(
saturate
)
{
{
if
(
interpret
==
CK_E4M3_FNUZ
)
if
constexpr
(
interpret
==
CK_E4M3_FNUZ
)
{
{
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
{
/// propagate NAN/INF, no clipping
...
@@ -570,7 +570,7 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u
...
@@ -570,7 +570,7 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u
}
}
}
}
if
(
stochastic_rounding
)
if
constexpr
(
stochastic_rounding
)
{
{
ival
=
(
interpret
==
CK_E4M3_FNUZ
)
||
(
interpret
==
CK_E4M3_OCP
)
ival
=
(
interpret
==
CK_E4M3_FNUZ
)
||
(
interpret
==
CK_E4M3_OCP
)
?
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
)
?
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
)
...
@@ -597,43 +597,59 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u
...
@@ -597,43 +597,59 @@ cast_to_f8_from_f32(float v, bool saturate, ck_fp8_interpretation_t interpret, u
/**
/**
* \brief convert float to @p fp8_storage_t
* \brief convert float to @p fp8_storage_t
*
*
* \tparam interp interpretation of fp8
* \tparam sat saturation of fp8
* \tparam sat saturation of fp8
* \param f float number
* \param f float number
* \param interp interpretation of fp8
* \return fp8_storage_t
* \return fp8_storage_t
*/
*/
template
<
ck_saturation_t
sat
=
CK_SATFINITE
>
template
<
ck_fp8_interpretation_t
interp
,
ck_saturation_t
sat
=
CK_SATFINITE
,
bool
stochastic_rounding
=
false
>
#if CK_FP8_CVT_FAST_PATH
#if CK_FP8_CVT_FAST_PATH
__host__
__device__
static
inline
fp8_storage_t
__host__
__device__
static
inline
fp8_storage_t
cvt_float_to_fp8
(
const
float
f
)
cvt_float_to_fp8
(
const
float
f
,
const
ck_fp8_interpretation_t
interp
)
{
{
internal
::
__is_interpret_supported
(
interp
);
internal
::
__is_interpret_supported
(
interp
);
return
internal
::
cast_to_f8_from_f32
<
false
>
(
f
,
sat
==
CK_SATFINITE
,
interp
);
uint32_t
rng
=
0
;
if
constexpr
(
stochastic_rounding
)
{
constexpr
int
seed
=
1254739
;
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
f
),
f
);
}
return
internal
::
cast_to_f8_from_f32
<
interp
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
#else
#else
#if CK_USE_OCP_FP8
#if CK_USE_OCP_FP8
__host__
__device__
static
inline
fp8_storage_t
__host__
__device__
static
inline
fp8_storage_t
cvt_float_to_fp8
(
const
float
f
)
cvt_float_to_fp8
(
const
float
f
,
const
ck_fp8_interpretation_t
interp
)
{
{
#else
#else
__host__
static
inline
fp8_storage_t
cvt_float_to_fp8
(
const
float
f
,
__host__
static
inline
fp8_storage_t
cvt_float_to_fp8
(
const
float
f
)
const
ck_fp8_interpretation_t
interp
)
{
{
#endif
#endif
if
(
interp
==
CK_E4M3_FNUZ
)
uint32_t
rng
=
0
;
if
constexpr
(
stochastic_rounding
)
{
{
return
internal
::
cast_to_f8
<
float
,
3
,
4
,
true
,
sat
==
CK_SATFINITE
>
(
f
);
constexpr
int
seed
=
1254739
;
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
f
),
f
);
}
if
constexpr
(
interp
==
CK_E4M3_FNUZ
)
{
return
internal
::
cast_to_f8
<
float
,
3
,
4
,
true
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
}
else
if
(
interp
==
CK_E5M2_FNUZ
)
else
if
(
interp
==
CK_E5M2_FNUZ
)
{
{
return
internal
::
cast_to_f8
<
float
,
2
,
5
,
true
,
sat
==
CK_SATFINITE
>
(
f
);
return
internal
::
cast_to_f8
<
float
,
2
,
5
,
true
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
}
else
if
(
interp
==
CK_E4M3_OCP
)
else
if
(
interp
==
CK_E4M3_OCP
)
{
{
return
internal
::
cast_to_f8
<
float
,
3
,
4
,
false
,
sat
==
CK_SATFINITE
>
(
f
);
return
internal
::
cast_to_f8
<
float
,
3
,
4
,
false
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
}
else
if
(
interp
==
CK_E5M2_OCP
)
else
if
(
interp
==
CK_E5M2_OCP
)
{
{
return
internal
::
cast_to_f8
<
float
,
2
,
5
,
false
,
sat
==
CK_SATFINITE
>
(
f
);
return
internal
::
cast_to_f8
<
float
,
2
,
5
,
false
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
}
else
else
{
{
...
@@ -734,7 +750,20 @@ template <>
...
@@ -734,7 +750,20 @@ template <>
inline
__host__
__device__
f8_ocp_t
f8_convert_rne
<
f8_ocp_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_ocp_t
f8_convert_rne
<
f8_ocp_t
,
float
>
(
float
x
)
{
{
return
f8_ocp_t
{
return
f8_ocp_t
{
internal
::
cvt_float_to_fp8
<
f8_ocp_t
::
default_saturation
>
(
x
,
f8_ocp_t
::
default_interpret
)};
internal
::
cvt_float_to_fp8
<
f8_ocp_t
::
default_interpret
,
f8_ocp_t
::
default_saturation
>
(
x
)};
}
// Declare a template function for fp8 conversion using stochastic rounding (SR).
// Y is the destination type and X is the source type.
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
// convert fp32 to fp8 (f8_ocp_t) with stochastic rounding: forwards to
// internal::cvt_float_to_fp8 using f8_ocp_t's default interpretation and
// saturation mode, with the stochastic_rounding template argument set to true
template
<
>
inline
__host__
__device__
f8_ocp_t
f8_convert_sr
<
f8_ocp_t
,
float
>
(
float
x
)
{
return
f8_ocp_t
{
internal
::
cvt_float_to_fp8
<
f8_ocp_t
::
default_interpret
,
f8_ocp_t
::
default_saturation
,
true
>
(
x
)};
}
}
#if CK_USE_OCP_FP8
#if CK_USE_OCP_FP8
...
...
test/data_type/test_fp8_ocp.cpp
View file @
c76b765a
...
@@ -77,7 +77,58 @@ TEST(FP8OCP, ConvertFP32Nearest)
...
@@ -77,7 +77,58 @@ TEST(FP8OCP, ConvertFP32Nearest)
ASSERT_TRUE
((
f8_nan
.
data
&
0x7f
)
==
0x7f
);
ASSERT_TRUE
((
f8_nan
.
data
&
0x7f
)
==
0x7f
);
}
}
TEST
(
FP8OCP
,
ConvertFP32Stochastic
)
{}
TEST(FP8OCP, ConvertFP32Stochastic)
{
    // Round trip under test: float -> f8_ocp_t (stochastic rounding) -> float.
    const auto round_trip = [](float value) {
        return type_convert<float>(f8_convert_sr<f8_ocp_t>(value));
    };

    // Tolerance for the comparisons that are not required to be bit-exact.
    const float tolerance = 1e-6;

    // Zero must survive the round trip exactly.
    ASSERT_NEAR(0.0f, round_trip(0.0f), 0.0f);

    // The smallest positive normal float must come back within the tolerance.
    const float smallest_float = std::numeric_limits<float>::min();
    ASSERT_NEAR(smallest_float, round_trip(smallest_float), tolerance);

    // The largest f8_ocp_t value, expressed as a float, must round-trip exactly.
    const auto f8_max_as_float = type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max());
    ASSERT_NEAR(f8_max_as_float, round_trip(f8_max_as_float), 0.0f);

    // The largest float must be clipped to the fp8 maximum (saturation to finite).
    ASSERT_NEAR(f8_max_as_float, round_trip(std::numeric_limits<float>::max()), 0.0f);

    // Float infinity must also map to the fp8 maximum (saturation to finite).
    ASSERT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
              f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::infinity()));

    // A positive normal value must come back within the tolerance.
    float positive_value = 0.017578125f;
    ASSERT_NEAR(positive_value, round_trip(positive_value), tolerance);

    // The smallest-magnitude normal fp8 value (negative sign) must round-trip exactly.
    float negative_value = -0.015625f; //-2^-6
    ASSERT_NEAR(negative_value, round_trip(negative_value), 0.0f);

    // A positive subnormal value must come back within the tolerance.
    positive_value = 0.00390625f;
    ASSERT_NEAR(positive_value, round_trip(positive_value), tolerance);

    // The smallest-magnitude subnormal fp8 value (negative sign) must round-trip exactly.
    constexpr auto smallest_subnormal = -0.001953125f; //-2^-9
    ASSERT_NEAR(smallest_subnormal, round_trip(smallest_subnormal), 0.0f);

    // Below the smallest fp8 subnormal the result alternates between 0 and 2^-9,
    // so accept anything within 2^-9 of zero.
    auto below_smallest_subnormal = 0.0009765625f; // 2^-10
    ASSERT_NEAR(0.0f, round_trip(below_smallest_subnormal), 0.001953125f);

    // A quiet NaN input must produce an fp8 value whose low seven bits are all set.
    auto nan_as_f8 = f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
    ASSERT_TRUE((nan_as_f8.data & 0x7f) == 0x7f);
}
TEST
(
FP8OCP
,
ConvertFP16Nearest
)
{}
TEST
(
FP8OCP
,
ConvertFP16Nearest
)
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment