Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
988fab58
"vscode:/vscode.git/clone" did not exist on "4d0fdcd525bec25e30d3cfe0fd747728761694ff"
Commit
988fab58
authored
Nov 09, 2023
by
Umang Yadav
Browse files
add unit-tests for fp8e4m3fnuz
parent
30005e6a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
23 deletions
+48
-23
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+12
-20
src/include/migraphx/migraphx_hip_f8_impl.hpp
src/include/migraphx/migraphx_hip_f8_impl.hpp
+2
-2
test/float8.cpp
test/float8.cpp
+34
-1
No files found.
src/include/migraphx/migraphx_float8.hpp
View file @
988fab58
...
...
@@ -106,13 +106,13 @@ struct hip_f8
// default constructor
MIGRAPHX_HIP_HOST_DEVICE
constexpr
hip_f8
()
=
default
;
// default copy constructor
MIGRAPHX_HIP_HOST_DEVICE
constexpr
hip_f8
(
const
hip_f8
&
y
)
=
default
;
MIGRAPHX_HIP_HOST_DEVICE
constexpr
hip_f8
(
const
hip_f8
<
T
>
&
y
)
=
default
;
struct
from_bits_t
{
};
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
MIGRAPHX_HIP_HOST_DEVICE
constexpr
hip_f8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
MIGRAPHX_HIP_HOST_DEVICE
explicit
constexpr
hip_f8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code
...
...
@@ -481,8 +481,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
quiet_NaN
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
>
(
static_cast
<
uint8_t
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7
9
));
return
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7
F
,
migraphx_fp8
::
hip_f8
<>::
from_bits
(
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
max
()
...
...
@@ -503,13 +503,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
infinity
()
{
if
constexpr
(
MIGRAPHX_FP8_FNUZ
)
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
static_cast
<
uint8_t
>
(
0x80
));
}
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
static_cast
<
uint8_t
>
(
0x78
));
return
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7F
,
migraphx_fp8
::
hip_f8
<>::
from_bits
());
}
};
...
...
@@ -524,8 +519,9 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
quiet_NaN
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
static_cast
<
uint8_t
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7d
));
return
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7d
,
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>::
from_bits
());
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
max
()
...
...
@@ -546,13 +542,9 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
infinity
()
{
if
constexpr
(
MIGRAPHX_FP8_FNUZ
)
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
static_cast
<
uint8_t
>
(
0x80
));
}
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
static_cast
<
uint8_t
>
(
0x7c
));
return
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7c
,
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>::
from_bits
());
}
};
/*
...
...
src/include/migraphx/migraphx_hip_f8_impl.hpp
View file @
988fab58
...
...
@@ -132,11 +132,11 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
// handle negative zero
if
((
sizeof
(
T
)
==
4
and
x
==
0x80000000
)
or
(
sizeof
(
T
)
==
2
and
x
==
0x8000
))
{
if
(
we
==
4
or
(
we
==
5
and
negative_zero_nan
)
)
if
(
negative_zero_nan
)
{
return
0
;
}
else
if
(
we
==
5
)
// E5M2
else
{
return
0x80
;
}
...
...
test/float8.cpp
View file @
988fab58
...
...
@@ -166,7 +166,10 @@ TEST_CASE(test_nan_1)
TEST_CASE
(
test_nan_2
)
{
migraphx_fp8
::
fp8e4m3fnuz
fp8_nan
(
std
::
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>::
quiet_NaN
());
auto
fnan
=
std
::
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>::
quiet_NaN
();
std
::
cout
<<
uint32_t
(
fnan
.
data
)
<<
std
::
endl
;
migraphx_fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
.
data
,
migraphx_fp8
::
fp8e4m3fnuz
::
from_bits
());
std
::
cout
<<
uint32_t
(
fp8_nan
.
data
)
<<
std
::
endl
;
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
...
...
@@ -199,4 +202,34 @@ TEST_CASE(test_infinity_3)
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
TEST_CASE
(
test_numeric_max_1
)
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx_fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>::
max
());
}
TEST_CASE
(
test_numeric_max_2
)
{
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>::
max
();
migraphx_fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>::
max
());
}
TEST_CASE
(
test_numeric_lowest_1
)
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx_fp8
::
fp8e4m3fnuz
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>::
lowest
());
}
TEST_CASE
(
test_numeric_lowest_2
)
{
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>::
lowest
();
migraphx_fp8
::
fp8e4m3fnuz
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>::
lowest
());
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment