Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4c47048f
Commit
4c47048f
authored
Nov 04, 2024
by
Rostyslav Geyyer
Browse files
Add stochastic rounding tests
parent
b73f83fd
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
60 additions
and
2 deletions
+60
-2
include/ck/ck.hpp
include/ck/ck.hpp
+3
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+27
-1
test/data_type/test_fp4.cpp
test/data_type/test_fp4.cpp
+30
-1
No files found.
include/ck/ck.hpp
View file @
4c47048f
...
...
@@ -156,6 +156,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F4_CONVERSION 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
...
include/ck/utility/type_convert.hpp
View file @
4c47048f
...
...
@@ -526,11 +526,37 @@ inline __host__ __device__ f4_t f4_convert_rne(float x)
#endif
}
// convert fp32 to fp4 with stochastic rounding
inline
__host__
__device__
f4_t
f4_convert_sr
(
float
x
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
// union
// {
// float fval;
// uint32_t i32val;
// uint8_t i8val[4]; // not endian independent
// } val;
// val.fval = x;
// uint32_t ival = 0;
// const float max_fp8 = 240.0f;
// // if x is not +/- infinity or nan
// if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// // clip float value
// val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
// ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false ->
// WORD0 val.i32val = ival; return val.i8val[0];
#else
return
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
,
rng
);
#endif
}
// convert fp32 to fp4
template
<
>
inline
__host__
__device__
f4_t
type_convert
<
f4_t
,
float
>
(
float
x
)
{
#if CK_USE_SR_F
8
_CONVERSION
#if CK_USE_SR_F
4
_CONVERSION
return
f4_convert_sr
(
x
);
#else
return
f4_convert_rne
(
x
);
...
...
test/data_type/test_fp4.cpp
View file @
4c47048f
...
...
@@ -6,6 +6,7 @@
#include "ck/utility/type_convert.hpp"
using
ck
::
f4_convert_rne
;
using
ck
::
f4_convert_sr
;
using
ck
::
f4_t
;
using
ck
::
type_convert
;
...
...
@@ -19,10 +20,11 @@ TEST(FP8, NumericLimits)
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
MaxSubnorm
(),
f4_t
{
0x1
});
}
TEST
(
FP
8
,
ConvertFP32Nearest
)
TEST
(
FP
4
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// set maximum fp4 value
float
max_fp4
=
6.0
f
;
// convert 0 float to fp4 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f4_convert_rne
(
0.0
f
)),
abs_tol
);
...
...
@@ -44,3 +46,30 @@ TEST(FP8, ConvertFP32Nearest)
neg_float
=
-
0.5
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f4_convert_rne
(
neg_float
)),
abs_tol
);
}
TEST
(
FP4
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// set maximum fp4 value
float
max_fp4
=
6.0
f
;
// convert 0 float to fp4 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f4_convert_sr
(
0.0
f
)),
abs_tol
);
// convert maximal f4_t to float and check if equal to 6.0
ASSERT_NEAR
(
max_fp4
,
type_convert
<
float
>
(
f4_convert_sr
(
max_fp4
)),
abs_tol
);
// convert maximal float to fp4 and back, check if clipped to 6.0
ASSERT_NEAR
(
max_fp4
,
type_convert
<
float
>
(
f4_convert_sr
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// positive norm float value to fp4 and back, check if holds
float
pos_float
=
1.0
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f4_convert_sr
(
pos_float
)),
abs_tol
);
// negative norm float value to fp4 and back, check if holds
float
neg_float
=
-
1.5
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f4_convert_sr
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp4 and back, check if holds
pos_float
=
0.5
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f4_convert_sr
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp4 and back, check if holds
neg_float
=
-
0.5
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f4_convert_sr
(
neg_float
)),
abs_tol
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment