Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2bd1b9cf
Commit
2bd1b9cf
authored
Oct 11, 2024
by
Andriy Roshchenko
Browse files
Implement FP8OCP tests for half_t type conversions.
parent
c76b765a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
195 additions
and
15 deletions
+195
-15
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+77
-13
test/data_type/test_fp8_ocp.cpp
test/data_type/test_fp8_ocp.cpp
+118
-2
No files found.
include/ck/utility/data_type.hpp
View file @
2bd1b9cf
...
...
@@ -359,10 +359,9 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// This has been modified to handle double types as well
template
<
typename
T
,
bool
is_fnuz
>
__host__
__device__
static
inline
T
cast_from_f8
(
fp8_storage_t
x
,
int
wm
,
int
we
,
bool
clip
=
false
)
template
<
typename
T
,
int
wm
,
int
we
,
bool
is_fnuz
,
bool
clip
=
false
>
__host__
__device__
static
inline
T
cast_from_f8
(
fp8_storage_t
x
)
{
// TODO: synchronize with f8_utils.hpp implementation for FNUZ
constexpr
bool
is_half
=
__hip_internal
::
is_same
<
T
,
_Float16
>::
value
;
constexpr
bool
is_float
=
__hip_internal
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_double
=
__hip_internal
::
is_same
<
T
,
double
>::
value
;
...
...
@@ -514,7 +513,8 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x, int wm, int we
}
#if CK_FP8_CVT_FAST_PATH
static
__device__
float
cast_to_f32_from_f8
(
fp8_storage_t
v
,
uint32_t
interpret
)
template
<
ck_fp8_interpretation_t
interpret
>
static
__device__
float
cast_to_f32_from_f8
(
fp8_storage_t
v
)
{
union
{
...
...
@@ -523,10 +523,18 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret)
}
val
;
val
.
i8val
[
0
]
=
v
;
float
fval
=
(
interpret
==
internal
::
CK_E4M3_FNUZ
)
||
(
interpret
==
internal
::
CK_E4M3_OCP
)
?
__builtin_amdgcn_cvt_f32_fp8
(
val
.
i32val
,
0
)
:
__builtin_amdgcn_cvt_f32_bf8
(
val
.
i32val
,
0
);
return
fval
;
static_assert
(
interpret
==
CK_E4M3_FNUZ
||
interpret
==
CK_E4M3_OCP
||
interpret
==
CK_E5M2_FNUZ
||
interpret
==
CK_E5M2_OCP
,
"Only FNUZ and OCP interpretations are supported"
);
if
constexpr
((
interpret
==
internal
::
CK_E4M3_FNUZ
)
||
(
interpret
==
internal
::
CK_E4M3_OCP
))
{
return
__builtin_amdgcn_cvt_f32_fp8
(
val
.
i32val
,
0
);
}
else
{
return
__builtin_amdgcn_cvt_f32_bf8
(
val
.
i32val
,
0
);
}
}
// The conversion function is from rocblas
...
...
@@ -659,6 +667,32 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
#endif // CK_FP8_CVT_FAST_PATH
}
/**
* \brief convert half_t to @p fp8_storage_t
*
* \tparam sat saturation of fp8
* \tparam interp interpretation of fp8
* \tparam stochastic_rounding switch between RNE and SR
* \param x half_t value
* \return fp8_storage_t
*/
template
<
ck_fp8_interpretation_t
interp
,
ck_saturation_t
sat
=
CK_SATFINITE
,
bool
stochastic_rounding
=
false
>
#if CK_FP8_CVT_FAST_PATH
__host__
__device__
static
inline
fp8_storage_t
cvt_half_t_to_fp8
(
const
half_t
x
)
{
internal
::
__is_interpret_supported
(
interp
);
#elif CK_USE_OCP_FP8
__host__
__device__
static
inline
fp8_storage_t
cvt_half_t_to_fp8
(
const
half_t
x
)
{
#else
__host__
static
inline
fp8_storage_t
cvt_half_t_to_fp8
(
const
half_t
x
)
{
#endif
return
cvt_float_to_fp8
<
interp
,
sat
,
stochastic_rounding
>
(
static_cast
<
float
>
(
x
));
}
/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
Inf are not supported. This gives us one additional number to represent.
NaN are represented by 1-0000-000 or 1-00000-00 */
...
...
@@ -706,15 +740,31 @@ struct f8_ocp_t
}
#if CK_USE_OCP_FP8
__host__
__device__
explicit
operator
float
()
const
{
__host__
__device__
explicit
operator
float
()
const
{
#else
__host__
explicit
operator
float
()
const
{
#endif
#if CK_FP8_CVT_FAST_PATH
return
internal
::
cast_to_f32_from_f8
(
this
->
data
,
default_interpret
);
return
internal
::
cast_to_f32_from_f8
<
default_interpret
>
(
this
->
data
);
#else
return
internal
::
cast_from_f8
<
float
,
false
>
(
this
->
data
,
wm
,
we
);
return
internal
::
cast_from_f8
<
float
,
wm
,
we
,
false
>
(
this
->
data
);
// XXX: clip==false must be consistent with operator half_t
#endif
}
#if CK_USE_OCP_FP8
__host__
__device__
explicit
operator
half_t
()
const
{
#else
__host__
explicit
operator
half_t
()
const
{
#endif
#if CK_FP8_CVT_FAST_PATH
return
static_cast
<
half_t
>
(
internal
::
cast_to_f32_from_f8
<
default_interpret
>
(
this
->
data
));
#else
return
internal
::
cast_from_f8
<
half_t
,
wm
,
we
,
false
>
(
this
->
data
);
// XXX: clip==false must be consistent with operator float
#endif
}
};
// namespace ck
...
...
@@ -752,12 +802,18 @@ inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
return
f8_ocp_t
{
internal
::
cvt_float_to_fp8
<
f8_ocp_t
::
default_interpret
,
f8_ocp_t
::
default_saturation
>
(
x
)};
}
// convert half_t to fp8 with rounding to nearest even
template
<
>
inline
__host__
__device__
f8_ocp_t
f8_convert_rne
<
f8_ocp_t
,
half_t
>
(
half_t
x
)
{
return
f8_ocp_t
{
internal
::
cvt_half_t_to_fp8
<
f8_ocp_t
::
default_interpret
,
f8_ocp_t
::
default_saturation
>
(
x
)};
}
// Declare a template function for fp8 conversion using RNE
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
// convert fp32 to fp8 with
rounding to nearest even
// convert fp32 to fp8 with
stochastic rounding
template
<
>
inline
__host__
__device__
f8_ocp_t
f8_convert_sr
<
f8_ocp_t
,
float
>
(
float
x
)
{
...
...
@@ -765,6 +821,14 @@ inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
internal
::
cvt_float_to_fp8
<
f8_ocp_t
::
default_interpret
,
f8_ocp_t
::
default_saturation
,
true
>
(
x
)};
}
// convert half_t to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_ocp_t
f8_convert_sr
<
f8_ocp_t
,
half_t
>
(
half_t
x
)
{
return
f8_ocp_t
{
internal
::
cvt_half_t_to_fp8
<
f8_ocp_t
::
default_interpret
,
f8_ocp_t
::
default_saturation
,
true
>
(
x
)};
}
#if CK_USE_OCP_FP8
using
f8_t
=
f8_ocp_t
;
...
...
test/data_type/test_fp8_ocp.cpp
View file @
2bd1b9cf
...
...
@@ -130,6 +130,122 @@ TEST(FP8OCP, ConvertFP32Stochastic)
ASSERT_TRUE
((
f8_nan
.
data
&
0x7f
)
==
0x7f
);
}
TEST
(
FP8OCP
,
ConvertFP16Nearest
)
{}
TEST
(
FP8OCP
,
ConvertFP16Nearest
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_tol
);
const
auto
max_f8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
max_f8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
f8_convert_rne
<
f8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive norm half_t value to fp8 and back, check if holds
half_t
pos_half_t
{
0.017578125
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// smallest normal fp8 value to fp8 and back, check if holds
half_t
neg_half_t
{
-
0.015625
f
};
//-2^-6
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t
=
half_t
{
0.00390625
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t
=
half_t
{
-
0.001953125
f
};
//-2^-9
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// smaller than min subnorm fp8 value to fp8 must be zero
auto
less_than_min_subnorm
=
half_t
{
0.0009765625
f
};
// 2^-10
ASSERT_EQ
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto
f8_nan
=
f8_convert_rne
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
internal
::
ocp_f8_is_nan
(
f8_nan
.
data
));
}
TEST
(
FP8OCP
,
ConvertFP16Stochastic
)
{}
TEST
(
FP8OCP
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
constexpr
auto
min_subnorm_fp8
=
0.001953125
f
;
// 2^-9
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t (6.103515625e-05) to fp8 and back
// alternates between 0 and 2^-9 (0.001953125)
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
min_subnorm_fp8
));
const
auto
max_f8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
max_f8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
f8_convert_sr
<
f8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive norm half_t value to fp8 and back, check if holds
half_t
pos_half_t
{
0.017578125
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// smallest normal fp8 value to fp8 and back, check if holds
half_t
neg_half_t
{
-
0.015625
f
};
//-2^-6
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t
=
half_t
{
0.00390625
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t
=
half_t
{
-
min_subnorm_fp8
};
//-2^-9
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto
less_than_min_subnorm
=
half_t
{
0.0009765625
f
};
// 2^-10
ASSERT_NEAR
(
type_convert
<
float
>
(
half_t_zero
),
type_convert
<
float
>
(
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
less_than_min_subnorm
))),
min_subnorm_fp8
);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto
f8_nan
=
f8_convert_sr
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
internal
::
ocp_f8_is_nan
(
f8_nan
.
data
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment