Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
85ba819b
Commit
85ba819b
authored
Nov 17, 2023
by
Umang Yadav
Browse files
constructor from float works with constexpr
parent
b36f72d3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
48 deletions
+34
-48
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+34
-48
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
85ba819b
...
@@ -75,7 +75,7 @@ struct float8
...
@@ -75,7 +75,7 @@ struct float8
// device specific optimized F8 down-conversion code
// device specific optimized F8 down-conversion code
template
<
bool
stochastic_rounding
=
false
>
template
<
bool
stochastic_rounding
=
false
>
static
MIGRAPHX_HIP_DEVICE
uint8_t
cast_to_f8_from_f32
(
float
v
,
uint32_t
rng
=
0
)
static
constexpr
MIGRAPHX_HIP_DEVICE
uint8_t
cast_to_f8_from_f32
(
float
v
,
uint32_t
rng
=
0
)
{
{
uint8_t
i8data
;
uint8_t
i8data
;
union
union
...
@@ -135,7 +135,7 @@ struct float8
...
@@ -135,7 +135,7 @@ struct float8
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: ON-DEVICE... always optimal bias
// NOTE: ON-DEVICE... always optimal bias
explicit
MIGRAPHX_HIP_DEVICE
explicit
constexpr
MIGRAPHX_HIP_DEVICE
float8
(
float
v
,
float8
(
float
v
,
migraphx
::
fp8
::
rounding_mode
rm
=
migraphx
::
fp8
::
rounding_mode
::
standard
,
migraphx
::
fp8
::
rounding_mode
rm
=
migraphx
::
fp8
::
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
uint32_t
rng
=
0
)
...
@@ -176,7 +176,7 @@ struct float8
...
@@ -176,7 +176,7 @@ struct float8
data
=
migraphx
::
fp8
::
impl
::
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#endif //
rocblas
_F8_
downcast_clipping
}
#endif //
MIGRAPHX
_F
P
8_
DOWNCAST_CLIPPING
}
}
}
}
}
...
@@ -314,58 +314,44 @@ struct float8
...
@@ -314,58 +314,44 @@ struct float8
}
}
};
};
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx::fp8::f8_type T> \
inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const migraphx::fp8::float8<T>& lhs, \
const migraphx::fp8::float8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
+
,
migraphx
::
fp8
::
float8
<
T
>
)
// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects.
MIGRAPHX_FP8_BINARY_OP
(
==
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
>=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
>
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
!=
,
bool
)
// https://onnx.ai/onnx/technical/float8.html
// https://onnx.ai/onnx/technical/float8.html
using
fp8e4m3fn
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e4m3fn
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
true
>
;
;
inline
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
fabs
(
fp8e4m3fnuz
v
)
{
v
.
data
=
v
.
data
&
0x7f
;
return
v
;
}
inline
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
fabs
(
fp8e4m3fn
v
)
{
v
.
data
=
v
.
data
&
0x7f
;
return
v
;
}
inline
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
fabs
(
fp8e5m2fnuz
v
)
// NOLINTNEXTLINE
{
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
v
.
data
=
v
.
data
&
0x7f
;
inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const T& lhs, const T& rhs) \
return
v
;
{ \
}
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
inline
MIGRAPHX_HIP_DEVICE
fp8e5m2
fabs
(
fp8e5m2
v
)
// NOLINTNEXTLINE
{
#define MIGRAPHX_FP8_UNARY_OP(unary_op, T) \
v
.
data
=
v
.
data
&
0x7f
;
inline constexpr MIGRAPHX_HIP_DEVICE T unary_op(T v) \
return
v
;
{ \
}
v.data = v.data & 0x7f; \
return v; \
}
#define MIGRAPHX_FP8_GEN_OP_OVERLOADS(T) \
MIGRAPHX_FP8_BINARY_OP(*, T, T) \
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \
MIGRAPHX_FP8_UNARY_OP(fabs, T)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e5m2
)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e5m2fnuz
)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e4m3fn
)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e4m3fnuz
)
template
<
>
template
<
>
class
numeric_limits
<
fp8e4m3fnuz
>
class
numeric_limits
<
fp8e4m3fnuz
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment