Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
60942349
Commit
60942349
authored
Nov 17, 2023
by
Umang Yadav
Browse files
Make FNUZ template param and add numeric limits
parent
d7339e8a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
105 additions
and
69 deletions
+105
-69
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+105
-69
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
60942349
...
...
@@ -46,10 +46,6 @@
#define MIGRAPHX_HIP_DEVICE __device__
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif // MIGRAPHX_FP8_FNUZ
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
...
...
@@ -90,14 +86,14 @@ enum class f8_type
template
<
typename
T
>
class
numeric_limits
;
template
<
migraphx
::
fp8
::
f8_type
T
=
migraphx
::
fp8
::
f8_type
::
fp8
>
template
<
migraphx
::
fp8
::
f8_type
T
=
migraphx
::
fp8
::
f8_type
::
fp8
,
bool
FNUZ
=
true
>
struct
float8
{
uint8_t
data
;
// default constructor
MIGRAPHX_HIP_HOST_DEVICE
constexpr
float8
()
=
default
;
// default copy constructor
MIGRAPHX_HIP_HOST_DEVICE
constexpr
float8
(
const
float8
<
T
>
&
y
)
=
default
;
MIGRAPHX_HIP_HOST_DEVICE
constexpr
float8
(
const
float8
&
y
)
=
default
;
struct
from_bits_t
{
};
...
...
@@ -195,11 +191,11 @@ struct float8
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
MIGRAPHX_FP8_
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
MIGRAPHX_FP8_
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
...
...
@@ -207,11 +203,11 @@ struct float8
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
MIGRAPHX_FP8_
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
MIGRAPHX_FP8_
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#endif // rocblas_F8_downcast_clipping}
}
...
...
@@ -278,11 +274,9 @@ struct float8
{
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
3
,
4
,
float
,
MIGRAPHX_FP8_FNUZ
/*negative_zero_nan*/
>
(
data
);
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
// else
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
2
,
5
,
float
,
MIGRAPHX_FP8_FNUZ
/*negative_zero_nan*/
>
(
data
);
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
/*
...
...
@@ -296,7 +290,7 @@ struct float8
// check for zero
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
bool
is_zero
()
const
{
if
constexpr
(
MIGRAPHX_FP8_
FNUZ
)
if
constexpr
(
FNUZ
)
{
return
data
==
0x00
;
}
...
...
@@ -309,7 +303,7 @@ struct float8
// check for nan
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
bool
is_nan
()
const
{
if
constexpr
(
MIGRAPHX_FP8_
FNUZ
)
if
constexpr
(
FNUZ
)
{
return
data
==
0x80
;
}
...
...
@@ -333,7 +327,7 @@ struct float8
// check for inf
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
bool
is_inf
()
const
{
if
constexpr
(
MIGRAPHX_FP8_
FNUZ
)
if
constexpr
(
FNUZ
)
{
return
data
==
0x80
;
}
...
...
@@ -458,97 +452,139 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest()
return
T
{
0xFF
,
T
::
from_bits
()};
}
using
fp8e4m3fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
;
// https://onnx.ai/onnx/technical/float8.html
using
fp8e4m3fn
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
true
>
;
template
<
>
class
numeric_limits
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
>
class
numeric_limits
<
fp8e4m3fnuz
>
{
public:
// TODO :figure out epsilon in Hex to make it constexpr
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
epsilon
()
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
epsilon
()
{
return
fp8e4m3fnuz
(
0x28
,
fp8e4m3fnuz
::
from_bits
());
}
// NOLINTNEXTLINE
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
quiet_NaN
()
{
return
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
(
0x28
,
migraphx
::
fp8
::
float8
<>::
from_bits
());
return
fp8e4m3fnuz
(
0x80
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
quiet_NaN
()
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
max
()
{
return
fp8e4m3fnuz
(
0x7F
,
fp8e4m3fnuz
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
min
()
{
return
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7F
,
migraphx
::
fp8
::
float8
<>::
from_bits
());
return
fp8e4m3fnuz
(
0x08
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
max
()
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
lowest
()
{
return
migraphx
::
fp8
::
F8_Max
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>>
(
);
return
fp8e4m3fnuz
(
0xFF
,
fp8e4m3fnuz
::
from_bits
()
);
}
};
// TODO figure out Hex value
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
min
()
template
<
>
class
numeric_limits
<
fp8e4m3fn
>
{
public:
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fn
epsilon
()
{
return
fp8e4m3fn
(
0x20
,
fp8e4m3fn
::
from_bits
());
}
// NOLINTNEXTLINE
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fn
quiet_NaN
()
{
return
static_cast
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>>
(
-
1.0
f
)
*
migraphx
::
fp8
::
F8_Max
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>>
();
return
fp8e4m3fn
(
0x7F
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
lowest
()
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fn
max
()
{
return
fp8e4m3fn
(
0x7E
,
fp8e4m3fn
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fn
min
()
{
return
migraphx
::
fp8
::
F8_Lowest
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>>
(
);
return
fp8e4m3fn
(
0x08
,
fp8e4m3fn
::
from_bits
()
);
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
infinity
()
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fn
lowest
()
{
return
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7F
,
migraphx
::
fp8
::
float8
<>::
from_bits
());
return
fp8e4m3fn
(
0xFE
,
fp8e4m3fn
::
from_bits
());
}
};
template
<
>
class
numeric_limits
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>
>
class
numeric_limits
<
fp8e5m2fnuz
>
{
public:
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>
epsilon
()
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2fnuz
epsilon
()
{
return
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>
(
0x34
,
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>::
from_bits
());
return
fp8e5m2fnuz
(
0x34
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>
quiet_NaN
()
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2fnuz
quiet_NaN
()
// NOLINT
{
return
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7d
,
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>::
from_bits
());
return
fp8e5m2fnuz
(
0x80
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>
max
()
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2fnuz
max
()
{
return
static_cast
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>>
(
migraphx
::
fp8
::
F8_Max
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>>
());
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
// TODO figure out constexpr value
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>
min
()
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2fnuz
min
()
{
return
static_cast
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>>
(
float
(
-
1.0
f
))
*
migraphx
::
fp8
::
F8_Max
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>>
();
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>
lowest
()
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2fnuz
lowest
()
{
return
fp8e5m2fnuz
(
0xFF
,
fp8e5m2fnuz
::
from_bits
());
}
};
template
<
>
class
numeric_limits
<
fp8e5m2
>
{
public:
static
constexpr
bool
has_infinity
=
true
;
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2
epsilon
()
{
return
migraphx
::
fp8
::
F8_Lowest
<
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>>
();
return
fp8e5m2
(
0x34
,
fp8e5m2
::
from_bits
());
}
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2
quiet_NaN
()
{
return
fp8e5m2
(
0xFF
,
fp8e5m2
::
from_bits
());
}
// NOLINT
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>
infinity
()
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
// 7C and FC both are infinity
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e5m2
infinity
()
{
return
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>
(
MIGRAPHX_FP8_FNUZ
?
0x80
:
0x7c
,
migraphx
::
fp8
::
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
>::
from_bits
());
return
fp8e5m2
(
0x7C
,
fp8e5m2
::
from_bits
());
}
};
/*
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment