Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
269ce6d1
Commit
269ce6d1
authored
Nov 16, 2023
by
Umang Yadav
Browse files
Bugfixes and additional tests
parent
26956f1d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
97 additions
and
13 deletions
+97
-13
src/include/migraphx/float8_impl.hpp
src/include/migraphx/float8_impl.hpp
+9
-8
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+1
-1
test/fp8e4m3fn.cpp
test/fp8e4m3fn.cpp
+22
-1
test/fp8e4m3fnuz.cpp
test/fp8e4m3fnuz.cpp
+29
-1
test/fp8e5m2.cpp
test/fp8e5m2.cpp
+12
-1
test/fp8e5m2fnuz.cpp
test/fp8e5m2fnuz.cpp
+24
-1
No files found.
src/include/migraphx/float8_impl.hpp
View file @
269ce6d1
...
...
@@ -134,15 +134,15 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0)
int
f8_exponent
=
0
;
int
exponent_diff
=
0
;
if
(
exponent
==
0
)
if
(
exponent
==
0
and
mantissa
!=
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here; we are mostly concerned with fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal
has exponent bias 15 while bf8 with
NANOO
has exponent bias 16. It means that there are some
numbers in fp16 denormal but they are bf8 (
NANOO
) normals - smallest bf8 (
NANOO
) normal is
has exponent bias 15 while bf8 with
FNUZ
has exponent bias 16. It means that there are some
numbers in fp16 denormal but they are bf8 (
FNUZ
) normals - smallest bf8 (
FNUZ
) normal is
2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1
are bf8 (
NANOO
) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
are bf8 (
FNUZ
) normal. In this case, the fp16 mantissa should be shifted left by 1 */
act_exponent
=
1
-
bias
;
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
...
...
@@ -152,10 +152,10 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0)
if
(
act_exponent
<=
f8_denormal_act_exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8
nanoo
mode, denormal exponent is -7, but if the fp32/fp16
For example fp8
FNUZ
mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implicit 1.
Therefore it needs to be adjusted to -6 and the mantissa shifted right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8
nanoo
*/
So for fp32/fp16, exponent -8 is the cut point to convert to fp8
FNUZ
*/
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
}
else
...
...
@@ -204,7 +204,8 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0)
mantissa
>>=
(
mfmt
-
Wm
);
// above range: quantize to maximum possible float of the same sign
const
int
max_exp
=
(
1
<<
We
)
-
(
NegativeZeroNan
?
1
:
2
);
// for e5m2 case, max_exp is 14, since exp = 15 is reserved for Infs and Nans
const
int
max_exp
=
(
1
<<
We
)
-
((
NegativeZeroNan
or
Wm
==
3
)
?
1
:
2
);
if
(
f8_exponent
>
max_exp
)
{
if
(
Clip
)
...
...
src/py/migraphx_py.cpp
View file @
269ce6d1
...
...
@@ -150,7 +150,7 @@ struct npy_format_descriptor<migraphx::fp8::fp8e4m3fnuz>
static
std
::
string
format
()
{
// following: https://docs.python.org/3/library/struct.html#format-characters
return
"
z
"
;
return
"
B
"
;
}
static
constexpr
auto
name
()
{
return
_
(
"fp8e4m3fnuz"
);
}
};
...
...
test/fp8e4m3fn.cpp
View file @
269ce6d1
...
...
@@ -117,6 +117,27 @@ TEST_CASE(test_fp8_cast_to_float)
})});
}
// Float -> fp8e4m3fn conversion check: for every table entry, the value built
// from the float key must compare equal (via migraphx::float_equal) to the
// value built from the mapped byte with the from_bits tag.
TEST_CASE(test_fp8_cast_from_float)
{
    using fp8_type = migraphx::fp8::fp8e4m3fn;

    // float input -> byte passed to the from_bits constructor
    const std::unordered_map<float, uint8_t> expected_bits = {
        {512, 0x7e},          {-512, 0xfe},
        {448, 0x7e},          {-448, 0xfe},
        {256, 0x78},          {-256, 0xf8},
        {240, 0x77},          {-240, 0xf7},
        {1e-07, 0x0},         {1e+07, 0x7e},
        {1, 0x38},            {-1, 0xb8},
        {0.1, 0x1d},          {0.11, 0x1e},
        {0.111, 0x1e},        {0.1111, 0x1e},
        {-0.1, 0x9d},         {-0.11, 0x9e},
        {-0.111, 0x9e},       {-0.1111, 0x9e},
        {0.2, 0x25},          {2, 0x40},
        {20, 0x5a},           {200, 0x74},
        {-0.2, 0xa5},         {-2, 0xc0},
        {-20, 0xda},          {-200, 0xf4},
        {0.5, 0x30},          {-0.5, 0xb0},
        {1.17549e-38, 0x0},   {1.4013e-45, 0x0},
        {0.0078125, 0x4},     {-0.0078125, 0x84},
        {0.000976562, 0x0},   {-0.000976562, 0x80},
        {0.000488281, 0x0},   {-0.000488281, 0x80}};

    // True when converting the float key yields the same fp8 value as the raw byte.
    const auto converts_to_expected = [](const auto& entry) {
        return migraphx::float_equal(fp8_type(entry.first),
                                     fp8_type(entry.second, fp8_type::from_bits()));
    };

    EXPECT(bool{std::all_of(expected_bits.begin(), expected_bits.end(), converts_to_expected)});
}
TEST_CASE
(
test_positive_zero
)
{
float
zero
=
0.0
;
...
...
@@ -241,7 +262,7 @@ TEST_CASE(test_binary_ops)
auto
f
=
migraphx
::
fp8
::
fp8e4m3fn
(
-
10.0
);
EXPECT
(
bool
{
e
>
f
});
EXPECT
(
bool
{
f
<
e
});
EXPECT
(
bool
(
f
<=
e
)
);
EXPECT
(
bool
{
f
<=
e
}
);
EXPECT
(
bool
{
e
>=
f
});
EXPECT
(
bool
{
e
<=
e
});
EXPECT
(
bool
{
f
>=
f
});
...
...
test/fp8e4m3fnuz.cpp
View file @
269ce6d1
...
...
@@ -138,6 +138,34 @@ TEST_CASE(test_fp8_cast_to_float)
})});
}
// Float -> fp8e4m3fnuz conversion check: for every table entry, the value built
// from the float key must compare equal (via migraphx::float_equal) to the
// value built from the mapped byte with the from_bits tag.
TEST_CASE(test_fp8_cast_from_float)
{
    using fp8_type = migraphx::fp8::fp8e4m3fnuz;

    // float input -> byte passed to the from_bits constructor
    const std::unordered_map<float, uint8_t> expected_bits = {
        {256, 0x7f},          {-256, 0xff},
        {240, 0x7f},          {-240, 0xff},
        {1e-07, 0x0},         {1e+07, 0x7f},
        {1, 0x40},            {-1, 0xc0},
        {0.1, 0x25},          {0.11, 0x26},
        {0.111, 0x26},        {0.1111, 0x26},
        {-0.1, 0xa5},         {-0.11, 0xa6},
        {-0.111, 0xa6},       {-0.1111, 0xa6},
        {0.2, 0x2d},          {2, 0x48},
        {20, 0x62},           {200, 0x7c},
        {-0.2, 0xad},         {-2, 0xc8},
        {-20, 0xe2},          {-200, 0xfc},
        {0.5, 0x38},          {-0.5, 0xb8},
        {1.17549e-38, 0x0},   {1.4013e-45, 0x0},
        {0.00390625, 0x4},    {-0.00390625, 0x84},
        {0.00195312, 0x2},    {-0.00195312, 0x82},
        {0.000976562, 0x1},   {-0.000976562, 0x81},
        {0.000488281, 0x0},   {-0.000488281, 0x0}};

    // True when converting the float key yields the same fp8 value as the raw byte.
    const auto converts_to_expected = [](const auto& entry) {
        return migraphx::float_equal(fp8_type(entry.first),
                                     fp8_type(entry.second, fp8_type::from_bits()));
    };

    EXPECT(bool{std::all_of(expected_bits.begin(), expected_bits.end(), converts_to_expected)});
}
TEST_CASE
(
test_positive_zero
)
{
float
zero
=
0.0
;
...
...
@@ -256,7 +284,7 @@ TEST_CASE(test_binary_ops)
auto
f
=
migraphx
::
fp8
::
fp8e4m3fnuz
(
-
10.0
);
EXPECT
(
bool
{
e
>
f
});
EXPECT
(
bool
{
f
<
e
});
EXPECT
(
bool
(
f
<=
e
)
);
EXPECT
(
bool
{
f
<=
e
}
);
EXPECT
(
bool
{
e
>=
f
});
EXPECT
(
bool
{
e
<=
e
});
EXPECT
(
bool
{
f
>=
f
});
...
...
test/fp8e5m2.cpp
View file @
269ce6d1
...
...
@@ -314,6 +314,17 @@ TEST_CASE(test_fp8_cast_to_float)
})});
}
// Float -> fp8e5m2 conversion check: for every table entry, the value built
// from the float key must compare equal (via migraphx::float_equal) to the
// value built from the mapped byte with the from_bits tag.
//
// The table used to be empty, which made std::all_of trivially true so the
// test verified nothing. It now holds inputs that are exactly representable in
// e5m2 (1 sign, 5 exponent, 2 mantissa bits; exponent bias 15), so the expected
// encodings do not depend on the rounding or clipping mode.
TEST_CASE(test_fp8_cast_from_float)
{
    using fp8_type = migraphx::fp8::fp8e5m2;

    // float input -> byte passed to the from_bits constructor
    const std::unordered_map<float, uint8_t> test_vals = {
        {0.25, 0x34}, {-0.25, 0xb4},
        {0.5, 0x38},  {-0.5, 0xb8},
        {1, 0x3c},    {-1, 0xbc},
        {1.5, 0x3e},  {-1.5, 0xbe},
        {2, 0x40},    {-2, 0xc0},
        {4, 0x44},    {-4, 0xc4}};

    EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto& sample) {
        return migraphx::float_equal(fp8_type(sample.first),
                                     fp8_type(sample.second, fp8_type::from_bits()));
    })});
}
TEST_CASE
(
test_positive_zero
)
{
float
zero
=
0.0
;
...
...
@@ -438,7 +449,7 @@ TEST_CASE(test_binary_ops)
auto
f
=
migraphx
::
fp8
::
fp8e5m2
(
-
10.0
);
EXPECT
(
bool
{
e
>
f
});
EXPECT
(
bool
{
f
<
e
});
EXPECT
(
bool
(
f
<=
e
)
);
EXPECT
(
bool
{
f
<=
e
}
);
EXPECT
(
bool
{
e
>=
f
});
EXPECT
(
bool
{
e
<=
e
});
EXPECT
(
bool
{
f
>=
f
});
...
...
test/fp8e5m2fnuz.cpp
View file @
269ce6d1
...
...
@@ -308,6 +308,29 @@ TEST_CASE(test_fp8_cast_to_float)
})});
}
// Float -> fp8e5m2fnuz conversion check: for every table entry, the value built
// from the float key must compare equal (via migraphx::float_equal) to the
// value built from the mapped byte with the from_bits tag.
TEST_CASE(test_fp8_cast_from_float)
{
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;

    // float input -> byte passed to the from_bits constructor
    const std::unordered_map<float, uint8_t> expected_bits = {
        {57344, 0x7f},        {-57344, 0xff},
        {60000, 0x7f},        {-60000, 0xff},
        {448, 0x63},          {-448, 0xe3},
        {256, 0x60},          {-256, 0xe0},
        {240, 0x60},          {-240, 0xe0},
        {3.05176e-05, 0x4},   {-3.05176e-05, 0x84},
        {1.52588e-05, 0x2},   {-1.52588e-05, 0x82},
        {7.62939e-06, 0x1},   {-7.62939e-06, 0x81},
        {3.81469e-06, 0x0},   {-3.81469e-06, 0x0},
        {1e+07, 0x7f},
        {1, 0x40},            {-1, 0xc0},
        {0.1, 0x32},          {0.11, 0x33},
        {0.111, 0x33},        {0.1111, 0x33},
        {-0.1, 0xb2},         {-0.11, 0xb3},
        {-0.111, 0xb3},       {-0.1111, 0xb3},
        {0.2, 0x36},          {2, 0x44},
        {20, 0x51},           {200, 0x5e},
        {-0.2, 0xb6},         {-2, 0xc4},
        {-20, 0xd1},          {-200, 0xde},
        {0.5, 0x3c},          {-0.5, 0xbc},
        {1.17549e-38, 0x0},   {1.4013e-45, 0x0},
    };

    // True when converting the float key yields the same fp8 value as the raw byte.
    const auto converts_to_expected = [](const auto& entry) {
        return migraphx::float_equal(fp8_type(entry.first),
                                     fp8_type(entry.second, fp8_type::from_bits()));
    };

    EXPECT(bool{std::all_of(expected_bits.begin(), expected_bits.end(), converts_to_expected)});
}
TEST_CASE
(
test_positive_zero
)
{
float
zero
=
0.0
;
...
...
@@ -426,7 +449,7 @@ TEST_CASE(test_binary_ops)
auto
f
=
migraphx
::
fp8
::
fp8e5m2fnuz
(
-
10.0
);
EXPECT
(
bool
{
e
>
f
});
EXPECT
(
bool
{
f
<
e
});
EXPECT
(
bool
(
f
<=
e
)
);
EXPECT
(
bool
{
f
<=
e
}
);
EXPECT
(
bool
{
e
>=
f
});
EXPECT
(
bool
{
e
<=
e
});
EXPECT
(
bool
{
f
>=
f
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment