Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f9542d5b
Commit
f9542d5b
authored
Nov 20, 2023
by
Umang Yadav
Browse files
Put numeric_max and numeeric lowest into float8
parent
836e201e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
24 deletions
+24
-24
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+19
-0
src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp
...gets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp
+3
-16
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+1
-0
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+1
-8
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
f9542d5b
...
...
@@ -33,6 +33,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/float8_impl.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace
migraphx
{
namespace
fp8
{
...
...
@@ -538,6 +539,24 @@ class numeric_limits<fp8e5m2>
};
}
// namespace fp8
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_MIN_MAX(T) \
template <> \
constexpr T numeric_max<T, void>() \
{ \
return fp8::numeric_limits<T>::max(); \
} \
template <> \
constexpr T numeric_lowest<T>() \
{ \
return fp8::numeric_limits<T>::lowest(); \
}
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e4m3fnuz
);
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e5m2fnuz
);
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e4m3fn
);
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e5m2
);
}
// namespace migraphx
// =================================================================================================
#if defined(__clang__)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp
View file @
f9542d5b
...
...
@@ -23,26 +23,13 @@
#ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#include <migraphx/kernels/bit_cast.hpp>
#include <migraphx/kernels/type_traits.hpp>
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier"
#endif
namespace
migraphx
{
namespace
detail
{
template
<
bool
B
,
class
T
,
class
F
>
struct
conditional
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
}
// namespace detail
namespace
fp8
{
namespace
impl
{
...
...
@@ -58,7 +45,7 @@ __device__ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng
static_assert
(
is_float
or
is_half
,
"Only float can be cast to f8"
);
const
uint32_t
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>
::
type
x
;
typename
migraphx
::
conditional
_t
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>
x
;
if
constexpr
(
sizeof
(
T
)
==
4
)
x
=
migraphx
::
bit_cast
<
uint32_t
>
(
f_x
);
...
...
@@ -304,7 +291,7 @@ __device__ constexpr T cast_from_f8(uint8_t x)
else
if
(
Wm
==
3
and
(
x
==
0x7F
or
x
==
0xFF
))
return
f_nan
;
}
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>
::
type
retval
;
typename
migraphx
::
conditional
_t
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>
retval
;
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
We
-
1
))
+
1
-
(
NegativeZeroNan
?
1
:
0
);
// NOLINT
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
f9542d5b
...
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/float8.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
f9542d5b
...
...
@@ -26,7 +26,6 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/float8.hpp>
namespace
migraphx
{
...
...
@@ -231,8 +230,7 @@ constexpr unsigned long int_max(unsigned long n)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
is_same
<
T
,
migraphx
::
half
>
{}
or
is_same
<
T
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
{})
>
is_same
<
T
,
migraphx
::
half
>
{})
>
constexpr
T
numeric_max
()
{
if
constexpr
(
is_integral
<
T
>
{})
...
...
@@ -248,9 +246,6 @@ constexpr T numeric_max()
return
__FLT_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
half
>
{})
return
__FLT16_MAX__
;
// TODO: Do it generically for all fp8 types
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
{})
return
migraphx
::
fp8
::
numeric_limits
<
T
>::
max
();
else
return
0
;
}
...
...
@@ -265,8 +260,6 @@ constexpr T numeric_lowest()
else
return
-
numeric_max
<
T
>
()
-
1
;
}
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
{})
return
migraphx
::
fp8
::
numeric_limits
<
T
>::
lowest
();
else
{
return
-
numeric_max
<
T
>
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment