Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
155a2b17
"testing/vscode:/vscode.git/clone" did not exist on "f5d9da46788674b326ace0714c47ad36f39c1de8"
Commit
155a2b17
authored
Nov 10, 2023
by
Umang Yadav
Browse files
move FNUZ as template parameter
parent
9bc18287
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
81 deletions
+42
-81
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+42
-81
No files found.
src/include/migraphx/migraphx_float8.hpp
View file @
155a2b17
...
@@ -29,10 +29,6 @@
...
@@ -29,10 +29,6 @@
#pragma clang diagnostic ignored "-Wc++20-extensions"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__
#endif // __clang__
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif // MIGRAPHX_FP8_FNUZ
// We are clipping in down conversion by default
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
...
@@ -73,10 +69,10 @@ enum class f8_type
...
@@ -73,10 +69,10 @@ enum class f8_type
fp8
=
1
// s1e4m3
fp8
=
1
// s1e4m3
};
};
template
<
typename
T
>
template
<
typename
T
,
bool
FNUZ
=
true
>
class
numeric_limits
;
class
numeric_limits
;
template
<
migraphx_fp8
::
f8_type
T
=
migraphx_fp8
::
f8_type
::
fp8
>
template
<
migraphx_fp8
::
f8_type
T
=
migraphx_fp8
::
f8_type
::
fp8
,
bool
FNUZ
=
true
>
struct
float8
struct
float8
{
{
uint8_t
data
=
0x00
;
uint8_t
data
=
0x00
;
...
@@ -100,11 +96,11 @@ struct float8
...
@@ -100,11 +96,11 @@ struct float8
{
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_f8_impl
::
data
=
migraphx_f8_impl
::
cast_to_f8
<
3
,
4
,
float
,
MIGRAPHX_FP8_
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx_fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_f8_impl
::
data
=
migraphx_f8_impl
::
cast_to_f8
<
3
,
4
,
float
,
MIGRAPHX_FP8_
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx_fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
}
...
@@ -112,11 +108,11 @@ struct float8
...
@@ -112,11 +108,11 @@ struct float8
{
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_f8_impl
::
data
=
migraphx_f8_impl
::
cast_to_f8
<
2
,
5
,
float
,
MIGRAPHX_FP8_
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx_fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_f8_impl
::
data
=
migraphx_f8_impl
::
cast_to_f8
<
2
,
5
,
float
,
MIGRAPHX_FP8_
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx_fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#endif // rocblas_F8_downcast_clipping}
#endif // rocblas_F8_downcast_clipping}
}
}
...
@@ -126,16 +122,14 @@ struct float8
...
@@ -126,16 +122,14 @@ struct float8
{
{
if
constexpr
(
T
==
migraphx_fp8
::
f8_type
::
fp8
)
if
constexpr
(
T
==
migraphx_fp8
::
f8_type
::
fp8
)
{
{
return
migraphx_f8_impl
::
return
migraphx_f8_impl
::
cast_from_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
cast_from_f8
<
3
,
4
,
float
,
MIGRAPHX_FP8_FNUZ
/*negative_zero_nan*/
>
(
data
);
}
// else
}
// else
return
migraphx_f8_impl
::
cast_from_f8
<
2
,
5
,
float
,
MIGRAPHX_FP8_FNUZ
/*negative_zero_nan*/
>
(
return
migraphx_f8_impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
data
);
}
}
inline
constexpr
bool
is_zero
()
const
inline
constexpr
bool
is_zero
()
const
{
{
if
constexpr
(
MIGRAPHX_FP8_
FNUZ
)
if
constexpr
(
FNUZ
)
{
{
return
data
==
0x00
;
return
data
==
0x00
;
}
}
...
@@ -147,7 +141,7 @@ struct float8
...
@@ -147,7 +141,7 @@ struct float8
inline
constexpr
bool
is_nan
()
const
inline
constexpr
bool
is_nan
()
const
{
{
if
constexpr
(
MIGRAPHX_FP8_
FNUZ
)
if
constexpr
(
FNUZ
)
{
{
return
data
==
0x80
;
return
data
==
0x80
;
}
}
...
@@ -170,7 +164,7 @@ struct float8
...
@@ -170,7 +164,7 @@ struct float8
inline
constexpr
bool
is_inf
()
const
inline
constexpr
bool
is_inf
()
const
{
{
if
constexpr
(
MIGRAPHX_FP8_
FNUZ
)
if
constexpr
(
FNUZ
)
{
{
return
data
==
0x80
;
return
data
==
0x80
;
}
}
...
@@ -218,7 +212,7 @@ struct float8
...
@@ -218,7 +212,7 @@ struct float8
inline
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
inline
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
{
{
if
((
rhs
.
is_zero
()
and
this
->
is_zero
())
or
if
((
rhs
.
is_zero
()
and
this
->
is_zero
())
or
(
fabs
(
rhs
-
*
this
)
<
migraphx_fp8
::
numeric_limits
<
float8
<
T
>>::
epsilon
()))
(
fabs
(
rhs
-
*
this
)
<
migraphx_fp8
::
numeric_limits
<
float8
<
T
,
FNUZ
>>::
epsilon
()))
return
true
;
return
true
;
else
if
(
rhs
.
is_nan
()
or
rhs
.
is_inf
()
or
this
->
is_nan
()
or
this
->
is_inf
())
else
if
(
rhs
.
is_nan
()
or
rhs
.
is_inf
()
or
this
->
is_nan
()
or
this
->
is_inf
())
return
false
;
return
false
;
...
@@ -289,123 +283,90 @@ constexpr T F8_Lowest()
...
@@ -289,123 +283,90 @@ constexpr T F8_Lowest()
return
T
{
0xFF
,
T
::
from_bits
()};
return
T
{
0xFF
,
T
::
from_bits
()};
}
}
using
fp8e4m3fnuz
=
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
;
using
fp8e4m3fn
=
float8
<
migraphx_fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx_fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx_fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx_fp8
::
f8_type
::
bf8
,
true
>
;
template
<
>
template
<
>
class
numeric_limits
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
>
class
numeric_limits
<
fp8e4m3fnuz
>
{
{
public:
public:
// TODO :figure out epsilon in Hex to make it constexpr
static
constexpr
fp8e4m3fnuz
epsilon
()
static
constexpr
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
epsilon
()
{
{
return
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
(
return
fp8e4m3fnuz
(
0x28
,
migraphx_fp8
::
float8
<>::
from_bits
());
0x28
,
migraphx_fp8
::
float8
<>::
from_bits
());
}
}
static
constexpr
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
quiet_NaN
()
static
constexpr
fp8e4m3fnuz
quiet_NaN
()
{
return
fp8e4m3fnuz
(
0x80
,
fp8e4m3fnuz
::
from_bits
());
}
{
return
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7F
,
migraphx_fp8
::
float8
<>::
from_bits
());
}
static
constexpr
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
max
()
static
constexpr
fp8e4m3fnuz
max
()
{
return
migraphx_fp8
::
F8_Max
<
fp8e4m3fnuz
>
();
}
{
return
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>>
();
}
// TODO figure out Hex value
// TODO figure out Hex value
static
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
min
()
static
fp8e4m3fnuz
min
()
{
{
return
static_cast
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>>
(
-
1.0
f
)
*
return
static_cast
<
fp8e4m3fnuz
>
(
-
1.0
f
)
*
migraphx_fp8
::
F8_Max
<
fp8e4m3fnuz
>
();
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>>
();
}
}
static
constexpr
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
lowest
()
static
constexpr
fp8e4m3fnuz
lowest
()
{
return
migraphx_fp8
::
F8_Lowest
<
fp8e4m3fnuz
>
();
}
{
return
migraphx_fp8
::
F8_Lowest
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>>
();
}
static
constexpr
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
infinity
()
static
constexpr
fp8e4m3fnuz
infinity
()
{
return
fp8e4m3fnuz
(
0x80
,
fp8e4m3fnuz
::
from_bits
());
}
{
return
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
fp8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7F
,
migraphx_fp8
::
float8
<>::
from_bits
());
}
};
};
template
<
>
template
<
>
class
numeric_limits
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>
>
class
numeric_limits
<
fp8e5m2fnuz
>
{
{
public:
public:
static
constexpr
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>
epsilon
()
static
constexpr
fp8e5m2fnuz
epsilon
()
{
return
fp8e5m2fnuz
(
0x34
,
fp8e5m2fnuz
::
from_bits
());
}
{
return
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>
(
0x34
,
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>::
from_bits
());
}
static
constexpr
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>
quiet_NaN
()
static
constexpr
fp8e5m2fnuz
quiet_NaN
()
{
return
fp8e5m2fnuz
(
0x80
,
fp8e5m2fnuz
::
from_bits
());
}
{
return
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7d
,
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>::
from_bits
());
}
static
constexpr
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>
max
()
static
constexpr
fp8e5m2fnuz
max
()
{
{
return
static_cast
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>>
(
return
static_cast
<
fp8e5m2fnuz
>
(
migraphx_fp8
::
F8_Max
<
fp8e5m2fnuz
>
());
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>>
());
}
}
// TODO figure out constexpr value
// TODO figure out constexpr value
static
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>
min
()
static
fp8e5m2fnuz
min
()
{
{
return
static_cast
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>>
(
float
(
-
1.0
f
))
*
return
static_cast
<
fp8e5m2fnuz
>
(
float
(
-
1.0
f
))
*
migraphx_fp8
::
F8_Max
<
fp8e5m2fnuz
>
();
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>>
();
}
static
constexpr
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>
lowest
()
{
return
migraphx_fp8
::
F8_Lowest
<
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>>
();
}
}
static
constexpr
fp8e5m2fnuz
lowest
()
{
return
migraphx_fp8
::
F8_Lowest
<
fp8e5m2fnuz
>
();
}
static
constexpr
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>
infinity
()
static
constexpr
fp8e5m2fnuz
infinity
()
{
return
fp8e5m2fnuz
(
0x80
,
fp8e5m2fnuz
::
from_bits
());
}
{
return
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7c
,
migraphx_fp8
::
float8
<
migraphx_fp8
::
f8_type
::
bf8
>::
from_bits
());
}
};
};
}
// namespace migraphx_fp8
}
// namespace migraphx_fp8
// =================================================================================================
// =================================================================================================
// define numeric limits for the new data type
// define numeric limits for the new data type
namespace
std
{
namespace
std
{
inline
bool
isfinite
(
migraphx_fp8
::
f
loat8
<
migraphx_fp8
::
f8_type
::
fp8
>
x
)
// NOLINT
inline
bool
isfinite
(
migraphx_fp8
::
f
p8e4m3fnuz
x
)
// NOLINT
{
{
return
x
.
is_inf
();
return
x
.
is_inf
();
}
}
inline
bool
isfinite
(
migraphx_fp8
::
f
loat8
<
migraphx_fp8
::
f8_type
::
bf8
>
x
)
// NOLINT
inline
bool
isfinite
(
migraphx_fp8
::
f
p8e5m2fnuz
x
)
// NOLINT
{
{
return
x
.
is_inf
();
return
x
.
is_inf
();
}
}
inline
bool
isnan
(
migraphx_fp8
::
f
loat8
<
migraphx_fp8
::
f8_type
::
fp8
>
x
)
// NOLINT
inline
bool
isnan
(
migraphx_fp8
::
f
p8e4m3fnuz
x
)
// NOLINT
{
{
return
x
.
is_nan
();
return
x
.
is_nan
();
}
}
inline
bool
isnan
(
migraphx_fp8
::
f
loat8
<
migraphx_fp8
::
f8_type
::
bf8
>
x
)
// NOLINT
inline
bool
isnan
(
migraphx_fp8
::
f
p8e5m2fnuz
x
)
// NOLINT
{
{
return
x
.
is_nan
();
return
x
.
is_nan
();
}
}
template
<
>
template
<
>
class
numeric_limits
<
migraphx_fp8
::
f
loat8
<
migraphx_fp8
::
f8_type
::
fp8
>
>
class
numeric_limits
<
migraphx_fp8
::
f
p8e4m3fnuz
>
:
public
migraphx_fp8
::
numeric_limits
<
migraphx_fp8
::
f
loat8
<
migraphx_fp8
::
f8_type
::
fp8
>
>
:
public
migraphx_fp8
::
numeric_limits
<
migraphx_fp8
::
f
p8e4m3fnuz
>
{
{
};
};
template
<
>
template
<
>
class
numeric_limits
<
migraphx_fp8
::
f
loat8
<
migraphx_fp8
::
f8_type
::
bf8
>
>
class
numeric_limits
<
migraphx_fp8
::
f
p8e5m2fnuz
>
:
public
migraphx_fp8
::
numeric_limits
<
migraphx_fp8
::
f
loat8
<
migraphx_fp8
::
f8_type
::
bf8
>
>
:
public
migraphx_fp8
::
numeric_limits
<
migraphx_fp8
::
f
p8e5m2fnuz
>
{
{
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment