Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
60092324
Commit
60092324
authored
Nov 14, 2023
by
Umang Yadav
Browse files
add tests
parent
12aac372
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
154 additions
and
31 deletions
+154
-31
src/include/migraphx/float8.hpp
src/include/migraphx/float8.hpp
+66
-31
test/fp8e4m3fn.cpp
test/fp8e4m3fn.cpp
+22
-0
test/fp8e4m3fnuz.cpp
test/fp8e4m3fnuz.cpp
+22
-0
test/fp8e5m2.cpp
test/fp8e5m2.cpp
+22
-0
test/fp8e5m2fnuz.cpp
test/fp8e5m2fnuz.cpp
+22
-0
No files found.
src/include/migraphx/float8.hpp
View file @
60092324
...
...
@@ -227,49 +227,84 @@ struct float8
}
};
// https://onnx.ai/onnx/technical/float8.html
using
fp8e4m3fn
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
true
>
;
/*
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
inline constexpr U operator binary_op(const T& lhs, const T& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats for binary ops
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP_GEN_FOR(T) \
MIGRAPHX_FP8_BINARY_OP(*, T, T) \
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fn)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2fnuz)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fnuz)
*/
// Special operator overloading
template
<
migraphx
::
fp8
::
f8_type
T
>
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
migraphx
::
fp8
::
float8
<
T
>&
rhs
)
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
fp8e4m3fnuz
&
rhs
)
{
return
os
<<
static_cast
<
float
>
(
rhs
);
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx::fp8::f8_type T> \
inline constexpr U operator binary_op(const migraphx::fp8::float8<T>& lhs, \
const migraphx::fp8::float8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
inline
fp8e4m3fnuz
fabs
(
fp8e4m3fnuz
v
)
{
v
.
data
=
v
.
data
&
0x7f
;
// NOLINT
return
v
;
}
// Special operator overloading
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
fp8e4m3fn
&
rhs
)
{
return
os
<<
static_cast
<
float
>
(
rhs
);
}
// TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
+
,
migraphx
::
fp8
::
float8
<
T
>
)
// TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding
// effects.
MIGRAPHX_FP8_BINARY_OP
(
==
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
>=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
>
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
!=
,
bool
)
template
<
migraphx
::
fp8
::
f8_type
T
>
inline
migraphx
::
fp8
::
float8
<
T
>
fabs
(
migraphx
::
fp8
::
float8
<
T
>
v
)
inline
fp8e4m3fn
fabs
(
fp8e4m3fn
v
)
{
v
.
data
=
v
.
data
&
0x7f
;
// NOLINT
return
v
;
}
//
https://onnx.ai/onnx/technical/float8.html
using
fp8e4m3fn
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
true
>
;
//
Special operator overloading
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
fp8e5m2fnuz
&
rhs
)
{
return
os
<<
static_cast
<
float
>
(
rhs
)
;
}
inline
fp8e5m2fnuz
fabs
(
fp8e5m2fnuz
v
)
{
v
.
data
=
v
.
data
&
0x7f
;
// NOLINT
return
v
;
}
// Special operator overloading
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
fp8e5m2
&
rhs
)
{
return
os
<<
static_cast
<
float
>
(
rhs
);
}
inline
fp8e5m2
fabs
(
fp8e5m2
v
)
{
v
.
data
=
v
.
data
&
0x7f
;
// NOLINT
return
v
;
}
template
<
>
class
numeric_limits
<
fp8e4m3fnuz
>
{
...
...
test/fp8e4m3fn.cpp
View file @
60092324
...
...
@@ -226,4 +226,26 @@ TEST_CASE(test_no_infinity)
EXPECT
(
not
bool
{
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
has_infinity
});
}
TEST_CASE
(
test_binary_ops
)
{
auto
a
=
migraphx
::
fp8
::
fp8e5m2
(
-
1.0
);
auto
b
=
migraphx
::
fp8
::
fp8e5m2
(
1.0
);
auto
c
=
migraphx
::
fp8
::
fp8e5m2
(
0.0
);
auto
d
=
migraphx
::
fp8
::
fp8e5m2
(
-
0.0
);
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
c
));
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
d
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
c
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
d
));
auto
e
=
migraphx
::
fp8
::
fp8e5m2
(
10.0
);
auto
f
=
migraphx
::
fp8
::
fp8e5m2
(
-
10.0
);
EXPECT
(
bool
{
e
>
f
});
EXPECT
(
bool
{
f
<
e
});
EXPECT
(
bool
(
f
<=
e
));
EXPECT
(
bool
{
e
>=
f
});
EXPECT
(
bool
{
e
<=
e
});
EXPECT
(
bool
{
f
>=
f
});
EXPECT
(
not
migraphx
::
float_equal
(
f
,
e
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fp8e4m3fnuz.cpp
View file @
60092324
...
...
@@ -241,4 +241,26 @@ TEST_CASE(test_no_infinity)
EXPECT
(
not
bool
{
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
has_infinity
});
}
TEST_CASE
(
test_binary_ops
)
{
auto
a
=
migraphx
::
fp8
::
fp8e5m2
(
-
1.0
);
auto
b
=
migraphx
::
fp8
::
fp8e5m2
(
1.0
);
auto
c
=
migraphx
::
fp8
::
fp8e5m2
(
0.0
);
auto
d
=
migraphx
::
fp8
::
fp8e5m2
(
-
0.0
);
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
c
));
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
d
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
c
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
d
));
auto
e
=
migraphx
::
fp8
::
fp8e5m2
(
10.0
);
auto
f
=
migraphx
::
fp8
::
fp8e5m2
(
-
10.0
);
EXPECT
(
bool
{
e
>
f
});
EXPECT
(
bool
{
f
<
e
});
EXPECT
(
bool
(
f
<=
e
));
EXPECT
(
bool
{
e
>=
f
});
EXPECT
(
bool
{
e
<=
e
});
EXPECT
(
bool
{
f
>=
f
});
EXPECT
(
not
migraphx
::
float_equal
(
f
,
e
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fp8e5m2.cpp
View file @
60092324
...
...
@@ -422,4 +422,26 @@ TEST_CASE(test_isfinite)
EXPECT
(
not
std
::
isfinite
(
migraphx
::
fp8
::
fp8e5m2
(
0xFC
,
migraphx
::
fp8
::
fp8e5m2
::
from_bits
())));
}
TEST_CASE
(
test_binary_ops
)
{
auto
a
=
migraphx
::
fp8
::
fp8e5m2
(
-
1.0
);
auto
b
=
migraphx
::
fp8
::
fp8e5m2
(
1.0
);
auto
c
=
migraphx
::
fp8
::
fp8e5m2
(
0.0
);
auto
d
=
migraphx
::
fp8
::
fp8e5m2
(
-
0.0
);
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
c
));
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
d
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
c
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
d
));
auto
e
=
migraphx
::
fp8
::
fp8e5m2
(
10.0
);
auto
f
=
migraphx
::
fp8
::
fp8e5m2
(
-
10.0
);
EXPECT
(
bool
{
e
>
f
});
EXPECT
(
bool
{
f
<
e
});
EXPECT
(
bool
(
f
<=
e
));
EXPECT
(
bool
{
e
>=
f
});
EXPECT
(
bool
{
e
<=
e
});
EXPECT
(
bool
{
f
>=
f
});
EXPECT
(
not
migraphx
::
float_equal
(
f
,
e
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fp8e5m2fnuz.cpp
View file @
60092324
...
...
@@ -411,4 +411,26 @@ TEST_CASE(test_no_infinity)
EXPECT
(
not
bool
{
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
has_infinity
});
}
TEST_CASE
(
test_binary_ops
)
{
auto
a
=
migraphx
::
fp8
::
fp8e5m2
(
-
1.0
);
auto
b
=
migraphx
::
fp8
::
fp8e5m2
(
1.0
);
auto
c
=
migraphx
::
fp8
::
fp8e5m2
(
0.0
);
auto
d
=
migraphx
::
fp8
::
fp8e5m2
(
-
0.0
);
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
c
));
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
d
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
c
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
d
));
auto
e
=
migraphx
::
fp8
::
fp8e5m2
(
10.0
);
auto
f
=
migraphx
::
fp8
::
fp8e5m2
(
-
10.0
);
EXPECT
(
bool
{
e
>
f
});
EXPECT
(
bool
{
f
<
e
});
EXPECT
(
bool
(
f
<=
e
));
EXPECT
(
bool
{
e
>=
f
});
EXPECT
(
bool
{
e
<=
e
});
EXPECT
(
bool
{
f
>=
f
});
EXPECT
(
not
migraphx
::
float_equal
(
f
,
e
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment