Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3fa15bcb
Commit
3fa15bcb
authored
Nov 21, 2024
by
Andriy Roshchenko
Browse files
Use `enum class` instead of `enum`
parent
97bad9f9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
67 additions
and
34 deletions
+67
-34
include/ck/utility/amd_ck_fp8.hpp
include/ck/utility/amd_ck_fp8.hpp
+67
-34
No files found.
include/ck/utility/amd_ck_fp8.hpp
View file @
3fa15bcb
...
...
@@ -32,9 +32,9 @@ using bf8_fnuz_t = unsigned _BitInt(8);
#endif
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#define CK_OFP8_CVT_FAST_PATH 1
#define CK_O
CP_
FP8_CVT_FAST_PATH 1
#else
#define CK_OFP8_CVT_FAST_PATH 0
#define CK_O
CP_
FP8_CVT_FAST_PATH 0
#endif
typedef
unsigned
char
fp8_storage_t
;
...
...
@@ -42,7 +42,7 @@ typedef unsigned char fp8_storage_t;
/**
* \brief Describes FP8 interpretation
*/
enum
ck_fp8_interpretation_t
enum
class
ck_fp8_interpretation_t
{
CK_E4M3_OCP
=
0
,
// OCP E4M3
CK_E5M2_OCP
=
1
,
// OCP E5M2
...
...
@@ -53,7 +53,7 @@ enum ck_fp8_interpretation_t
/**
* \brief Describes saturation behavior
*/
enum
ck_saturation_t
enum
class
ck_saturation_t
{
CK_NOSAT
=
0
,
// No saturation - replace with NaN or Inf
CK_SATFINITE
=
1
,
// Saturate to finite
...
...
@@ -250,11 +250,14 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
}
val
;
val
.
i8val
[
0
]
=
v
;
static_assert
(
interpret
==
CK_E4M3_FNUZ
||
interpret
==
CK_E4M3_OCP
||
interpret
==
CK_E5M2_FNUZ
||
interpret
==
CK_E5M2_OCP
,
static_assert
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_FNUZ
||
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
||
interpret
==
ck_fp8_interpretation_t
::
CK_E5M2_FNUZ
||
interpret
==
ck_fp8_interpretation_t
::
CK_E5M2_OCP
,
"Only FNUZ and OCP interpretations are supported"
);
if
constexpr
((
interpret
==
CK_E4M3_FNUZ
)
||
(
interpret
==
CK_E4M3_OCP
))
if
constexpr
((
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_FNUZ
)
||
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
))
{
return
__builtin_amdgcn_cvt_f32_fp8
(
val
.
i32val
,
0
);
}
...
...
@@ -269,11 +272,14 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
{
const
auto
i16val
=
bit_cast
<
uint16_t
>
(
v
);
static_assert
(
interpret
==
CK_E4M3_FNUZ
||
interpret
==
CK_E4M3_OCP
||
interpret
==
CK_E5M2_FNUZ
||
interpret
==
CK_E5M2_OCP
,
static_assert
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_FNUZ
||
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
||
interpret
==
ck_fp8_interpretation_t
::
CK_E5M2_FNUZ
||
interpret
==
ck_fp8_interpretation_t
::
CK_E5M2_OCP
,
"Only FNUZ and OCP interpretations are supported"
);
if
constexpr
((
interpret
==
CK_E4M3_FNUZ
)
||
(
interpret
==
CK_E4M3_OCP
))
if
constexpr
((
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_FNUZ
)
||
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
))
{
return
__builtin_amdgcn_cvt_pk_f32_fp8
(
i16val
,
false
);
}
...
...
@@ -295,8 +301,9 @@ struct f8_ocp_t
using
data_type
=
fp8_storage_t
;
data_type
data
;
static
constexpr
ck_saturation_t
default_saturation
=
CK_SATFINITE
;
static
constexpr
ck_fp8_interpretation_t
default_interpret
=
CK_E4M3_OCP
;
static
constexpr
ck_saturation_t
default_saturation
=
ck_saturation_t
::
CK_SATFINITE
;
static
constexpr
ck_fp8_interpretation_t
default_interpret
=
ck_fp8_interpretation_t
::
CK_E4M3_OCP
;
static
constexpr
unsigned
int
we
=
4
;
// exponent width
static
constexpr
unsigned
int
wm
=
3
;
// mantissa width
...
...
@@ -312,7 +319,7 @@ struct f8_ocp_t
__host__
explicit
operator
float
()
const
#endif
{
#if CK_OFP8_CVT_FAST_PATH
#if CK_O
CP_
FP8_CVT_FAST_PATH
return
fp8_impl
::
cast_to_f32_from_f8
<
default_interpret
>
(
this
->
data
);
#else
return
fp8_impl
::
cast_from_f8
<
float
,
wm
,
we
,
false
>
(
...
...
@@ -326,7 +333,7 @@ struct f8_ocp_t
__host__
explicit
operator
_Float16
()
const
#endif
{
#if CK_OFP8_CVT_FAST_PATH
#if CK_O
CP_
FP8_CVT_FAST_PATH
return
static_cast
<
_Float16
>
(
fp8_impl
::
cast_to_f32_from_f8
<
default_interpret
>
(
this
->
data
));
#else
return
fp8_impl
::
cast_from_f8
<
_Float16
,
wm
,
we
,
false
>
(
...
...
@@ -340,8 +347,9 @@ struct bf8_ocp_t
using
data_type
=
fp8_storage_t
;
data_type
data
;
static
constexpr
ck_saturation_t
default_saturation
=
CK_SATFINITE
;
static
constexpr
ck_fp8_interpretation_t
default_interpret
=
CK_E5M2_OCP
;
static
constexpr
ck_saturation_t
default_saturation
=
ck_saturation_t
::
CK_SATFINITE
;
static
constexpr
ck_fp8_interpretation_t
default_interpret
=
ck_fp8_interpretation_t
::
CK_E5M2_OCP
;
static
constexpr
unsigned
int
we
=
5
;
// exponent width
static
constexpr
unsigned
int
wm
=
2
;
// mantissa width
...
...
@@ -442,7 +450,7 @@ struct non_native_vector_base<f8_ocp_t, 2>
__host__
explicit
operator
float2_t
()
const
#endif
{
#if CK_OFP8_CVT_FAST_PATH
#if CK_O
CP_
FP8_CVT_FAST_PATH
return
fp8_impl
::
cast_to_f32x2_from_f8x2
<
f8_ocp_t
::
default_interpret
>
(
d
);
#else
return
float2_t
{
fp8_impl
::
cast_from_f8
<
float
,
f8_ocp_t
::
wm
,
f8_ocp_t
::
we
,
false
>
(
d
[
0
]),
...
...
@@ -529,14 +537,16 @@ namespace fp8_impl {
// Assertions to check for supported conversion types
#define __assert_ocp_support(interp) \
{ \
if(interp != CK_E4M3_OCP && interp != CK_E5M2_OCP) \
if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
{ \
__hip_assert(false && "type is unsupported by current target device"); \
} \
}
#define __assert_fnuz_support(interp) \
{ \
if(interp != CK_E4M3_FNUZ && interp != CK_E5M2_FNUZ) \
if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
{ \
__hip_assert(false && "type is unsupported by current target device"); \
} \
...
...
@@ -574,14 +584,14 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
if
constexpr
(
saturate
)
{
if
constexpr
(
interpret
==
CK_E4M3_FNUZ
)
if
constexpr
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_FNUZ
)
{
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
}
}
else
if
constexpr
(
interpret
==
CK_E4M3_OCP
)
else
if
constexpr
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
{
// OCP type
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
...
...
@@ -599,7 +609,8 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
if
constexpr
(
stochastic_rounding
)
{
ival
=
(
interpret
==
CK_E4M3_FNUZ
)
||
(
interpret
==
CK_E4M3_OCP
)
ival
=
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_FNUZ
)
||
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
?
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
)
:
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
...
...
@@ -607,7 +618,8 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
}
else
{
// RNE CVT
ival
=
(
interpret
==
CK_E4M3_FNUZ
)
||
(
interpret
==
CK_E4M3_OCP
)
ival
=
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_FNUZ
)
||
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
?
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
)
:
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
val
.
fval
,
...
...
@@ -897,7 +909,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
* \return fp8_storage_t
*/
template
<
ck_fp8_interpretation_t
interp
,
ck_saturation_t
sat
=
CK_SATFINITE
,
ck_saturation_t
sat
=
ck_saturation_t
::
CK_SATFINITE
,
bool
stochastic_rounding
=
false
>
#if CK_FP8_CVT_FAST_PATH
__host__
__device__
static
inline
fp8_storage_t
cvt_float_to_fp8
(
const
float
f
)
...
...
@@ -909,7 +921,8 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
constexpr
int
seed
=
1254739
;
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
f
),
f
);
}
return
cast_to_f8_from_f32
<
interp
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
return
cast_to_f8_from_f32
<
interp
,
sat
==
ck_saturation_t
::
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
#else
#if CK_USE_OCP_FP8
__host__
__device__
static
inline
fp8_storage_t
cvt_float_to_fp8
(
const
float
f
)
...
...
@@ -925,21 +938,41 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
f
),
f
);
}
if
constexpr
(
interp
==
CK_E4M3_FNUZ
)
if
constexpr
(
interp
==
ck_fp8_interpretation_t
::
CK_E4M3_FNUZ
)
{
return
cast_to_f8
<
float
,
3
,
4
,
true
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
return
cast_to_f8
<
float
,
3
,
4
,
true
,
sat
==
ck_saturation_t
::
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
else
if
constexpr
(
interp
==
CK_E5M2_FNUZ
)
else
if
constexpr
(
interp
==
ck_fp8_interpretation_t
::
CK_E5M2_FNUZ
)
{
return
cast_to_f8
<
float
,
2
,
5
,
true
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
return
cast_to_f8
<
float
,
2
,
5
,
true
,
sat
==
ck_saturation_t
::
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
else
if
constexpr
(
interp
==
CK_E4M3_OCP
)
else
if
constexpr
(
interp
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
{
return
cast_to_f8
<
float
,
3
,
4
,
false
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
return
cast_to_f8
<
float
,
3
,
4
,
false
,
sat
==
ck_saturation_t
::
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
else
if
constexpr
(
interp
==
CK_E5M2_OCP
)
else
if
constexpr
(
interp
==
ck_fp8_interpretation_t
::
CK_E5M2_OCP
)
{
return
cast_to_f8
<
float
,
2
,
5
,
false
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
return
cast_to_f8
<
float
,
2
,
5
,
false
,
sat
==
ck_saturation_t
::
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
else
{
...
...
@@ -959,7 +992,7 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
* \return fp8_storage_t
*/
template
<
ck_fp8_interpretation_t
interp
,
ck_saturation_t
sat
=
CK_SATFINITE
,
ck_saturation_t
sat
=
ck_saturation_t
::
CK_SATFINITE
,
bool
stochastic_rounding
=
false
>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__
__device__
static
inline
fp8_storage_t
cvt_half_t_to_fp8
(
const
_Float16
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment