Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2bd1b9cf
Commit
2bd1b9cf
authored
Oct 11, 2024
by
Andriy Roshchenko
Browse files
Implement FP8OCP tests for half_t type conversions.
parent
c76b765a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
195 additions
and
15 deletions
+195
-15
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+77
-13
test/data_type/test_fp8_ocp.cpp
test/data_type/test_fp8_ocp.cpp
+118
-2
No files found.
include/ck/utility/data_type.hpp
View file @
2bd1b9cf
...
@@ -359,10 +359,9 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -359,10 +359,9 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
// The conversion function is from rocblas
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// This has been modified to handle double types as well
// This has been modified to handle double types as well
template
<
typename
T
,
bool
is_fnuz
>
template
<
typename
T
,
int
wm
,
int
we
,
bool
is_fnuz
,
bool
clip
=
false
>
__host__
__device__
static
inline
T
cast_from_f8
(
fp8_storage_t
x
,
int
wm
,
int
we
,
bool
clip
=
false
)
__host__
__device__
static
inline
T
cast_from_f8
(
fp8_storage_t
x
)
{
{
// TODO: synchronize with f8_utils.hpp implementation for FNUZ
constexpr
bool
is_half
=
__hip_internal
::
is_same
<
T
,
_Float16
>::
value
;
constexpr
bool
is_half
=
__hip_internal
::
is_same
<
T
,
_Float16
>::
value
;
constexpr
bool
is_float
=
__hip_internal
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_float
=
__hip_internal
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_double
=
__hip_internal
::
is_same
<
T
,
double
>::
value
;
constexpr
bool
is_double
=
__hip_internal
::
is_same
<
T
,
double
>::
value
;
...
@@ -514,7 +513,8 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x, int wm, int we
...
@@ -514,7 +513,8 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x, int wm, int we
}
}
#if CK_FP8_CVT_FAST_PATH
#if CK_FP8_CVT_FAST_PATH
static
__device__
float
cast_to_f32_from_f8
(
fp8_storage_t
v
,
uint32_t
interpret
)
template
<
ck_fp8_interpretation_t
interpret
>
static
__device__
float
cast_to_f32_from_f8
(
fp8_storage_t
v
)
{
{
union
union
{
{
...
@@ -523,10 +523,18 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret)
...
@@ -523,10 +523,18 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v, uint32_t interpret)
}
val
;
}
val
;
val
.
i8val
[
0
]
=
v
;
val
.
i8val
[
0
]
=
v
;
float
fval
=
(
interpret
==
internal
::
CK_E4M3_FNUZ
)
||
(
interpret
==
internal
::
CK_E4M3_OCP
)
static_assert
(
interpret
==
CK_E4M3_FNUZ
||
interpret
==
CK_E4M3_OCP
||
?
__builtin_amdgcn_cvt_f32_fp8
(
val
.
i32val
,
0
)
interpret
==
CK_E5M2_FNUZ
||
interpret
==
CK_E5M2_OCP
,
:
__builtin_amdgcn_cvt_f32_bf8
(
val
.
i32val
,
0
);
"Only FNUZ and OCP interpretations are supported"
);
return
fval
;
if
constexpr
((
interpret
==
internal
::
CK_E4M3_FNUZ
)
||
(
interpret
==
internal
::
CK_E4M3_OCP
))
{
return
__builtin_amdgcn_cvt_f32_fp8
(
val
.
i32val
,
0
);
}
else
{
return
__builtin_amdgcn_cvt_f32_bf8
(
val
.
i32val
,
0
);
}
}
}
// The conversion function is from rocblas
// The conversion function is from rocblas
...
@@ -659,6 +667,32 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
...
@@ -659,6 +667,32 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
#endif // CK_FP8_CVT_FAST_PATH
#endif // CK_FP8_CVT_FAST_PATH
}
}
/**
 * \brief Convert a half_t value to its @p fp8_storage_t encoding.
 *
 * The input is cast to float and handed to cvt_float_to_fp8 with the same
 * template arguments, so all rounding and saturation behaviour is the one
 * implemented by that function.
 *
 * \tparam interp              interpretation of fp8
 * \tparam sat                 saturation of fp8
 * \tparam stochastic_rounding switch between RNE (false) and SR (true)
 * \param  x                   half_t value
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = CK_SATFINITE,
          bool stochastic_rounding = false>
// Device-callable when either the fast conversion path or OCP FP8 is enabled,
// host-only otherwise.
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__
#else
__host__
#endif
    static inline fp8_storage_t
    cvt_half_t_to_fp8(const half_t x)
{
#if CK_FP8_CVT_FAST_PATH
    // Only the fast path performs this check.
    // NOTE(review): presumably rejects interpretations the hardware conversion
    // does not support - confirm against the definition of the helper.
    internal::__is_interpret_supported(interp);
#endif
    const float widened = static_cast<float>(x);
    return cvt_float_to_fp8<interp, sat, stochastic_rounding>(widened);
}
/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
Inf are not supported. This gives us one additional number to represent.
Inf are not supported. This gives us one additional number to represent.
NaN are represented by 1-0000-000 or 1-00000-00 */
NaN are represented by 1-0000-000 or 1-00000-00 */
...
@@ -706,15 +740,31 @@ struct f8_ocp_t
...
@@ -706,15 +740,31 @@ struct f8_ocp_t
}
}
#if CK_USE_OCP_FP8
#if CK_USE_OCP_FP8
__host__
__device__
explicit
operator
float
()
const
{
__host__
__device__
explicit
operator
float
()
const
{
#else
#else
__host__
explicit
operator
float
()
const
__host__
explicit
operator
float
()
const
{
{
#endif
#endif
#if CK_FP8_CVT_FAST_PATH
#if CK_FP8_CVT_FAST_PATH
return
internal
::
cast_to_f32_from_f8
(
this
->
data
,
default_interpret
);
return
internal
::
cast_to_f32_from_f8
<
default_interpret
>
(
this
->
data
);
#else
#else
return
internal
::
cast_from_f8
<
float
,
false
>
(
this
->
data
,
wm
,
we
);
return
internal
::
cast_from_f8
<
float
,
wm
,
we
,
false
>
(
this
->
data
);
// XXX: clip==false must be consistent with operator half_t
#endif
}
#if CK_USE_OCP_FP8
__host__
__device__
explicit
operator
half_t
()
const
{
#else
__host__
explicit
operator
half_t
()
const
{
#endif
#if CK_FP8_CVT_FAST_PATH
return
static_cast
<
half_t
>
(
internal
::
cast_to_f32_from_f8
<
default_interpret
>
(
this
->
data
));
#else
return
internal
::
cast_from_f8
<
half_t
,
wm
,
we
,
false
>
(
this
->
data
);
// XXX: clip==false must be consistent with operator float
#endif
#endif
}
}
};
// namespace ck
};
// namespace ck
...
@@ -752,12 +802,18 @@ inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
...
@@ -752,12 +802,18 @@ inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
return
f8_ocp_t
{
return
f8_ocp_t
{
internal
::
cvt_float_to_fp8
<
f8_ocp_t
::
default_interpret
,
f8_ocp_t
::
default_saturation
>
(
x
)};
internal
::
cvt_float_to_fp8
<
f8_ocp_t
::
default_interpret
,
f8_ocp_t
::
default_saturation
>
(
x
)};
}
}
// Specialization for half_t -> f8_ocp_t using round-to-nearest-even.
// Uses the interpretation and saturation mode that f8_ocp_t declares as its
// defaults; stochastic rounding is left at cvt_half_t_to_fp8's default.
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, half_t>(half_t x)
{
    constexpr auto interp_mode = f8_ocp_t::default_interpret;
    constexpr auto sat_mode    = f8_ocp_t::default_saturation;

    const auto encoded = internal::cvt_half_t_to_fp8<interp_mode, sat_mode>(x);
    return f8_ocp_t{encoded};
}
// Declare a template function for fp8 conversion using SR (stochastic rounding)
// Declare a template function for fp8 conversion using SR (stochastic rounding)
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
// convert fp32 to fp8 with
rounding to nearest even
// convert fp32 to fp8 with
stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
f8_ocp_t
f8_convert_sr
<
f8_ocp_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_ocp_t
f8_convert_sr
<
f8_ocp_t
,
float
>
(
float
x
)
{
{
...
@@ -765,6 +821,14 @@ inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
...
@@ -765,6 +821,14 @@ inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
internal
::
cvt_float_to_fp8
<
f8_ocp_t
::
default_interpret
,
f8_ocp_t
::
default_saturation
,
true
>
(
internal
::
cvt_float_to_fp8
<
f8_ocp_t
::
default_interpret
,
f8_ocp_t
::
default_saturation
,
true
>
(
x
)};
x
)};
}
}
// Specialization for half_t -> f8_ocp_t using stochastic rounding.
// Uses the interpretation and saturation mode that f8_ocp_t declares as its
// defaults and explicitly enables the stochastic_rounding template switch.
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, half_t>(half_t x)
{
    constexpr auto interp_mode   = f8_ocp_t::default_interpret;
    constexpr auto sat_mode      = f8_ocp_t::default_saturation;
    constexpr bool use_sr        = true;

    const auto encoded = internal::cvt_half_t_to_fp8<interp_mode, sat_mode, use_sr>(x);
    return f8_ocp_t{encoded};
}
#if CK_USE_OCP_FP8
#if CK_USE_OCP_FP8
using
f8_t
=
f8_ocp_t
;
using
f8_t
=
f8_ocp_t
;
...
...
test/data_type/test_fp8_ocp.cpp
View file @
2bd1b9cf
...
@@ -130,6 +130,122 @@ TEST(FP8OCP, ConvertFP32Stochastic)
...
@@ -130,6 +130,122 @@ TEST(FP8OCP, ConvertFP32Stochastic)
ASSERT_TRUE
((
f8_nan
.
data
&
0x7f
)
==
0x7f
);
ASSERT_TRUE
((
f8_nan
.
data
&
0x7f
)
==
0x7f
);
}
}
TEST
(
FP8OCP
,
ConvertFP16Nearest
)
{}
// Round-trip checks for half_t -> f8_ocp_t -> half_t using round-to-nearest-even.
TEST(FP8OCP, ConvertFP16Nearest)
{
    // Tolerances: a small absolute tolerance, and zero for values that are
    // expected to survive the round trip exactly.
    constexpr half_t half_t_tol  = 1e-3;
    constexpr half_t half_t_zero = 0.0;

    // half_t -> fp8 (RNE) -> half_t
    const auto round_trip = [](half_t value) {
        return type_convert<half_t>(f8_convert_rne<f8_ocp_t>(value));
    };

    // Zero must survive the round trip.
    ASSERT_NEAR(half_t_zero, round_trip(half_t_zero), half_t_zero);

    // Minimal half_t value must survive within tolerance.
    const auto half_min = ck::NumericLimits<half_t>::Min();
    ASSERT_NEAR(half_min, round_trip(half_min), half_t_tol);

    // fp8 max expressed as half_t must map back onto itself.
    const auto fp8_max         = ck::NumericLimits<f8_ocp_t>::Max();
    const auto max_f8_t_half_t = type_convert<half_t>(fp8_max);
    ASSERT_NEAR(max_f8_t_half_t, round_trip(max_f8_t_half_t), half_t_zero);

    // Maximal half_t must be clipped to fp8 max (saturation to finite).
    ASSERT_NEAR(max_f8_t_half_t, round_trip(ck::NumericLimits<half_t>::Max()), half_t_zero);

    // half_t infinity must become the fp8 max value (saturation to finite).
    const auto half_inf = type_convert<half_t>(std::numeric_limits<float>::infinity());
    ASSERT_EQ(fp8_max, f8_convert_rne<f8_ocp_t>(half_inf));

    // Positive normal value.
    const half_t pos_normal{0.017578125f};
    ASSERT_NEAR(pos_normal, round_trip(pos_normal), half_t_tol);

    // Negated smallest normal fp8 magnitude: -2^-6.
    const half_t neg_min_normal{-0.015625f};
    ASSERT_NEAR(neg_min_normal, round_trip(neg_min_normal), half_t_zero);

    // Positive subnormal value.
    const half_t pos_subnormal{0.00390625f};
    ASSERT_NEAR(pos_subnormal, round_trip(pos_subnormal), half_t_tol);

    // Negated minimal subnormal fp8 magnitude: -2^-9.
    const half_t neg_min_subnormal{-0.001953125f};
    ASSERT_NEAR(neg_min_subnormal, round_trip(neg_min_subnormal), half_t_zero);

    // A value below the minimal fp8 subnormal (2^-10) must convert to zero.
    const half_t less_than_min_subnorm{0.0009765625f};
    ASSERT_EQ(half_t_zero, round_trip(less_than_min_subnorm));

    // A quiet NaN input must produce an fp8 NaN encoding.
    const auto f8_nan = f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
    ASSERT_TRUE(ck::internal::ocp_f8_is_nan(f8_nan.data));
}
TEST
(
FP8OCP
,
ConvertFP16Stochastic
)
{}
// Round-trip checks for half_t -> f8_ocp_t -> half_t using stochastic rounding.
TEST(FP8OCP, ConvertFP16Stochastic)
{
    // Tolerances: a small absolute tolerance, and zero for values that are
    // expected to survive the round trip exactly.
    constexpr half_t half_t_tol  = 1e-3;
    constexpr half_t half_t_zero = 0.0;
    // Minimal fp8 subnormal magnitude, 2^-9.
    constexpr auto min_subnorm_fp8 = 0.001953125f;

    // half_t -> fp8 (SR) -> half_t
    const auto round_trip = [](half_t value) {
        return type_convert<half_t>(f8_convert_sr<f8_ocp_t>(value));
    };

    // Zero must survive the round trip.
    ASSERT_NEAR(half_t_zero, round_trip(half_t_zero), half_t_zero);

    // Minimal half_t (6.103515625e-05) alternates between 0 and 2^-9
    // (0.001953125) after the round trip, so allow a tolerance of 2^-9.
    const auto half_min = ck::NumericLimits<half_t>::Min();
    ASSERT_NEAR(half_min, round_trip(half_min), type_convert<half_t>(min_subnorm_fp8));

    // fp8 max expressed as half_t must map back onto itself.
    const auto fp8_max         = ck::NumericLimits<f8_ocp_t>::Max();
    const auto max_f8_t_half_t = type_convert<half_t>(fp8_max);
    ASSERT_NEAR(max_f8_t_half_t, round_trip(max_f8_t_half_t), half_t_zero);

    // Maximal half_t must be clipped to fp8 max (saturation to finite).
    ASSERT_NEAR(max_f8_t_half_t, round_trip(ck::NumericLimits<half_t>::Max()), half_t_zero);

    // half_t infinity must become the fp8 max value (saturation to finite).
    const auto half_inf = type_convert<half_t>(std::numeric_limits<float>::infinity());
    ASSERT_EQ(fp8_max, f8_convert_sr<f8_ocp_t>(half_inf));

    // Positive normal value.
    const half_t pos_normal{0.017578125f};
    ASSERT_NEAR(pos_normal, round_trip(pos_normal), half_t_tol);

    // Negated smallest normal fp8 magnitude: -2^-6.
    const half_t neg_min_normal{-0.015625f};
    ASSERT_NEAR(neg_min_normal, round_trip(neg_min_normal), half_t_zero);

    // Positive subnormal value.
    const half_t pos_subnormal{0.00390625f};
    ASSERT_NEAR(pos_subnormal, round_trip(pos_subnormal), half_t_tol);

    // Negated minimal subnormal fp8 magnitude: -2^-9.
    const half_t neg_min_subnormal{-min_subnorm_fp8};
    ASSERT_NEAR(neg_min_subnormal, round_trip(neg_min_subnormal), half_t_zero);

    // A value below the minimal fp8 subnormal (2^-10) alternates between 0 and
    // 2^-9, so compare against zero with a tolerance of 2^-9 (in float).
    const half_t less_than_min_subnorm{0.0009765625f};
    ASSERT_NEAR(type_convert<float>(half_t_zero),
                type_convert<float>(round_trip(less_than_min_subnorm)),
                min_subnorm_fp8);

    // A quiet NaN input must produce an fp8 NaN encoding.
    const auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
    ASSERT_TRUE(ck::internal::ocp_f8_is_nan(f8_nan.data));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment