Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
35aebe59
Unverified
Commit
35aebe59
authored
Jan 27, 2025
by
Andriy Roshchenko
Committed by
GitHub
Jan 27, 2025
Browse files
Add OCP FP8 support in CK_TILE (#1829)
* Add OCP FP8 to CK_TILE * Validate OCP FP8 in FMHA FWD under VALID=1
parent
39dc25a9
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
621 additions
and
349 deletions
+621
-349
example/ck_tile/01_fmha/CMakeLists.txt
example/ck_tile/01_fmha/CMakeLists.txt
+5
-0
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+15
-3
include/ck_tile/core/numeric/float8.hpp
include/ck_tile/core/numeric/float8.hpp
+593
-340
include/ck_tile/core/numeric/half.hpp
include/ck_tile/core/numeric/half.hpp
+6
-5
include/ck_tile/core/numeric/numeric.hpp
include/ck_tile/core/numeric/numeric.hpp
+2
-1
No files found.
example/ck_tile/01_fmha/CMakeLists.txt
View file @
35aebe59
...
@@ -102,6 +102,11 @@ else()
...
@@ -102,6 +102,11 @@ else()
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0
)
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0
)
endif
()
endif
()
# conditionally specify the use of OCP_FP8
if
(
CK_USE_OCP_FP8
)
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8
)
endif
()
# Allow comparing floating points directly in order to check sentinel values
# Allow comparing floating points directly in order to check sentinel values
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal
)
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal
)
list
(
APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal
)
list
(
APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal
)
...
...
include/ck_tile/core/config.hpp
View file @
35aebe59
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx9__
#define __gfx9__
#endif
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx94__
#define __gfx94__
#endif
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
...
@@ -230,3 +230,15 @@
...
@@ -230,3 +230,15 @@
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifdef CK_TILE_USE_OCP_FP8
#define CK_TILE_USE_OCP_FP8 1
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_OCP_FP8 1
#else // for GPU code
#define CK_TILE_USE_OCP_FP8 0
#endif
include/ck_tile/core/numeric/float8.hpp
View file @
35aebe59
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...
@@ -14,6 +14,12 @@
...
@@ -14,6 +14,12 @@
#pragma once
#pragma once
#if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_TILE_FP8_CVT_DEVICE 1
#else
#define CK_TILE_FP8_CVT_DEVICE 0
#endif
namespace
ck_tile
{
namespace
ck_tile
{
// fp8 rounding modes
// fp8 rounding modes
...
@@ -25,15 +31,26 @@ enum class fp8_rounding_mode
...
@@ -25,15 +31,26 @@ enum class fp8_rounding_mode
stochastic
stochastic
};
};
/**
* \brief FP8 interpretation used in conversion algorithms
*/
enum
class
fp8_interpretation
{
E4M3_OCP
=
0
,
// OCP FP8 E4M3
E5M2_OCP
=
1
,
// OCP BF8 E5M2
E4M3_FNUZ
=
2
,
// FNUZ FP8 E4M3
E5M2_FNUZ
=
3
,
// FNUZ BF8 E5M2
};
/*
/*
* ______________NANOO_________________ | ______________IEEE________________
* ______________FNUZ_________________ | ______________OCP________________
* e4m3 e5m2 | e4m3 e5m2
* e4m3 e5m2 | e4m3 e5m2
* bias : 8 16 | 7 15
* bias : 8 16 | 7 15
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111(448) s.00000.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
...
@@ -55,10 +72,10 @@ struct alignas(1) float8_e4m3_t
...
@@ -55,10 +72,10 @@ struct alignas(1) float8_e4m3_t
{
{
static
constexpr
int
exponent
=
4
;
static
constexpr
int
exponent
=
4
;
static
constexpr
int
mantissa
=
3
;
static
constexpr
int
mantissa
=
3
;
#if
defined(__gfx94__)
#if
CK_TILE_USE_OCP_FP8
static
constexpr
int
bias
=
1
<<
(
exponent
-
1
);
// NANOO
static
constexpr
int
bias
=
7
;
// OCP
#else
#else
static
constexpr
int
bias
=
(
1
<<
(
exponent
-
1
))
-
1
;
// IEEE
static
constexpr
int
bias
=
8
;
// FNUZ
#endif
#endif
using
raw_type
=
uint8_t
;
using
raw_type
=
uint8_t
;
raw_type
data
;
raw_type
data
;
...
@@ -113,10 +130,10 @@ struct alignas(1) float8_e5m2_t
...
@@ -113,10 +130,10 @@ struct alignas(1) float8_e5m2_t
{
{
static
constexpr
int
exponent
=
5
;
static
constexpr
int
exponent
=
5
;
static
constexpr
int
mantissa
=
2
;
static
constexpr
int
mantissa
=
2
;
#if
defined(__gfx94__)
#if
CK_TILE_USE_OCP_FP8
static
constexpr
int
bias
=
1
<<
(
exponent
-
1
);
// NANOO
static
constexpr
int
bias
=
1
5
;
// OCP
#else
#else
static
constexpr
int
bias
=
(
1
<<
(
exponent
-
1
))
-
1
;
//
IEEE
static
constexpr
int
bias
=
1
6
;
//
FNUZ
#endif
#endif
using
raw_type
=
uint8_t
;
using
raw_type
=
uint8_t
;
raw_type
data
;
raw_type
data
;
...
@@ -183,501 +200,727 @@ struct native_t<bf8_t>
...
@@ -183,501 +200,727 @@ struct native_t<bf8_t>
};
};
#else
#else
using
fp8_t
=
_BitInt
(
8
);
using
fp8_t
=
_BitInt
(
8
);
using
fp8_raw_t
=
uint8_t
;
using
fp8_raw_t
=
uint8_t
;
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
bf8_raw_t
=
uint8_t
;
using
bf8_raw_t
=
uint8_t
;
#endif
#endif
// below is sw fp8 conversion, not utilizing hw instruction
template
<
typename
T
>
namespace
impl
{
struct
numeric_traits
;
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
template
<
>
CK_TILE_HOST_DEVICE
Y
run_cast_to_f8
(
X
x
,
uint32_t
rng
)
struct
numeric_traits
<
fp8_t
>
{
{
// fp8/bf8 exponent/mantissa layout
using
bitwise_type
=
fp8_raw_t
;
constexpr
int
out_exp
=
numeric_traits
<
Y
>::
exp
;
constexpr
int
out_mant
=
numeric_traits
<
Y
>::
mant
;
// original type exponent/mantissa layout
static
constexpr
int
exp
=
4
;
constexpr
int
in_exp
=
numeric_traits
<
X
>::
exp
;
static
constexpr
int
mant
=
3
;
constexpr
int
in_mant
=
numeric_traits
<
X
>::
mant
;
#if CK_TILE_USE_OCP_FP8
static
constexpr
int
bias
=
7
;
static
constexpr
fp8_interpretation
f8_interpret
=
fp8_interpretation
::
E4M3_OCP
;
#else
static
constexpr
int
bias
=
8
;
static
constexpr
fp8_interpretation
f8_interpret
=
fp8_interpretation
::
E4M3_FNUZ
;
#endif
static
constexpr
uint8_t
abs_mask
=
0x7F
;
};
int
exponent
,
bias
;
template
<
>
uint32_t
head
,
mantissa
,
sign
;
struct
numeric_traits
<
bf8_t
>
// nan code is same for float and half
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
using
bitwise_type
=
bf8_raw_t
;
constexpr
Y
nan_code
=
numeric
<
Y
>::
quiet_NaN
();
// __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
#if CK_TILE_USE_OCP_FP8
static
constexpr
int
bias
=
15
;
static
constexpr
fp8_interpretation
f8_interpret
=
fp8_interpretation
::
E5M2_OCP
;
#else
#else
constexpr
Y
nan_code
=
0x80
;
static
constexpr
int
bias
=
16
;
static
constexpr
fp8_interpretation
f8_interpret
=
fp8_interpretation
::
E5M2_FNUZ
;
#endif
#endif
static
constexpr
uint8_t
abs_mask
=
0x7F
;
};
constexpr
uint32_t
nan_mask
=
numeric_traits
<
X
>::
nan_mask
;
// below is sw fp8 conversion, not utilizing hw instruction
namespace
impl
{
template
<
typename
SrcT
,
typename
DstT
,
bool
clip
=
true
,
bool
stoch
=
false
>
CK_TILE_HOST_DEVICE
DstT
run_cast_to_f8
(
SrcT
src
,
unsigned
int
rng
=
0
)
{
static_assert
(
std
::
is_same
<
DstT
,
fp8_t
>::
value
||
std
::
is_same
<
DstT
,
bf8_t
>::
value
,
"DstT type must be fp8 or bf8."
);
// convert to bitwise
constexpr
bool
is_half
=
std
::
is_same
<
SrcT
,
half_t
>::
value
;
using
T_bitwise
=
typename
numeric_traits
<
X
>::
bitwise_typ
e
;
constexpr
bool
is_float
=
std
::
is_same
<
SrcT
,
float
>::
valu
e
;
T_bitwise
x_bitwise
=
*
(
reinterpret_cast
<
T_bitwise
*>
(
&
x
)
);
static_assert
(
is_half
||
is_float
,
"Only half and float can be cast to f8"
);
// unpack the input, depends on datatype
// fp8/bf8 type exponent/mantissa layout
head
=
x_bitwise
&
numeric_traits
<
X
>::
head_mask
;
constexpr
int
DstT_exp
=
numeric_traits
<
DstT
>::
exp
;
// exponent width of the destination type
mantissa
=
x_bitwise
&
numeric_traits
<
X
>::
mant
_mask
;
constexpr
int
DstT_mant
=
numeric_traits
<
DstT
>::
mant
;
// mantissa width of the destination type
exponent
=
(
head
>>
in_mant
)
&
numeric_traits
<
X
>::
exp_mask
;
constexpr
bool
is_fnuz
=
sign
=
head
>>
(
in_exp
+
in_mant
);
(
numeric_traits
<
DstT
>::
f8_interpret
==
fp8_interpretation
::
E4M3_FNUZ
)
||
bias
=
numeric_traits
<
X
>::
bias
;
(
numeric_traits
<
DstT
>::
f8_interpret
==
fp8_interpretation
::
E5M2_FNUZ
)
;
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
constexpr
int
SrcT_exp
=
numeric_traits
<
SrcT
>::
exp
;
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
constexpr
int
SrcT_mant
=
numeric_traits
<
SrcT
>::
mant
;
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
if
constexpr
(
negative_zero_nan
)
using
SrcT_bitwise
=
typename
numeric_traits
<
SrcT
>::
bitwise_type
;
SrcT_bitwise
src_bitwise
=
bit_cast
<
SrcT_bitwise
>
(
src
);
unsigned
long
long
head
,
mantissa
;
int
exponent
,
bias
;
unsigned
int
sign
;
unsigned
long
long
fInf
,
abs_mask
;
head
=
src_bitwise
&
numeric_traits
<
SrcT
>::
head_mask
;
mantissa
=
src_bitwise
&
numeric_traits
<
SrcT
>::
mant_mask
;
exponent
=
(
head
>>
SrcT_mant
)
&
numeric_traits
<
SrcT
>::
exp_mask
;
sign
=
head
>>
(
SrcT_exp
+
SrcT_mant
);
bias
=
numeric_traits
<
SrcT
>::
bias
;
fInf
=
numeric_traits
<
SrcT
>::
Inf
;
abs_mask
=
numeric_traits
<
SrcT
>::
abs_mask
;
unsigned
int
signed_inf
=
0
;
unsigned
int
nan
=
0
;
if
constexpr
(
is_fnuz
)
{
{
if
((
x_bitwise
&
nan_mask
)
==
nan_mask
)
signed_inf
=
clip
?
((
sign
<<
7
)
+
0x7f
)
:
0x80
;
return
nan_code
;
nan
=
0x80
;
}
}
else
else
{
{
if
((
x_bitwise
&
nan_mask
)
==
nan_mask
)
if
constexpr
(
DstT_exp
==
4
)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
{
// e4m3
signed_inf
=
(
sign
<<
7
)
+
(
clip
?
0x7e
:
0x7f
);
}
else
{
// e5m2
signed_inf
=
(
sign
<<
7
)
+
(
clip
?
0x7b
:
0x7c
);
}
nan
=
(
sign
<<
7
)
+
0x7f
;
}
// Max values
unsigned
long
long
ifmax
=
0
;
if
constexpr
(
is_float
)
{
if
constexpr
(
DstT_exp
==
5
)
{
ifmax
=
0x47600000
;
}
else
{
if
constexpr
(
is_fnuz
)
{
ifmax
=
0x43700000
;
}
else
{
ifmax
=
0x43E00000
;
}
}
}
else
if
constexpr
(
is_half
)
{
if
constexpr
(
DstT_exp
==
5
)
{
ifmax
=
0x7B00
;
}
else
{
if
constexpr
(
is_fnuz
)
{
ifmax
=
0x5B80
;
}
else
{
ifmax
=
0x5F00
;
}
}
}
}
// check if x is 0.0
// Deal with inf and NaNs
if
(
x_bitwise
==
0
)
if
((
src_bitwise
&
fInf
)
==
fInf
)
return
__builtin_bit_cast
(
Y
,
static_cast
<
uint8_t
>
(
0
));
{
if
constexpr
(
is_fnuz
)
return
signed_inf
;
// First need to check if it is normal or denorm as there is a difference of implict 1
return
mantissa
!=
0
?
nan
:
signed_inf
;
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
}
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
// exponent and mantissa again3
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
if
((
src_bitwise
&
abs_mask
)
>
ifmax
)
const
int
out_bias
=
(
1
<<
(
out_exp
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
{
const
int
out_denormal_act_exponent
=
1
-
out_bias
;
// actual exponent of f8 denormal
return
signed_inf
;
}
if
(
src_bitwise
==
0
)
{
return
0
;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const
int
f8_bias
=
(
1
<<
(
DstT_exp
-
1
))
-
1
+
(
is_fnuz
?
1
:
0
);
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
out
_exponent
,
exponent_diff
;
int
act_exponent
,
f8
_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
exponent bias 16. It means that there are some numbers in fp16 denormal but they
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
In this case, the fp16 mantissa should be shift left by 1 */
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
out
_denormal_act_exponent
-
exponent_diff
=
f8
_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
}
else
else
{
// fp32/fp16 is normal with implicit 1
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
out
_denormal_act_exponent
)
if
(
act_exponent
<=
f8
_denormal_act_exponent
)
{
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range.
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
out
_denormal_act_exponent
-
act_exponent
;
exponent_diff
=
f8
_denormal_act_exponent
-
act_exponent
;
}
}
else
else
{
// both fp32/fp16 and f8 are in normal range
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference
0
;
// exponent_diff=0 does not mean there is no difference for this case,
// for this case, act_exponent could be larger. Just
// act_exponent could be larger. Just
that it does not need shift mantissa
//
that it does not need shift mantissa
}
}
mantissa
+=
(
1
<<
in
_mant
);
// Add the implicit 1 into mantissa
mantissa
+=
(
1
ull
<<
SrcT
_mant
);
// Add the implicit 1 into mantissa
}
}
bool
midpoint
=
(
mantissa
&
((
1
<<
(
in_mant
-
out_mant
+
exponent_diff
))
-
1
))
==
bool
midpoint
=
(
mantissa
&
((
1ull
<<
(
SrcT_mant
-
DstT_mant
+
exponent_diff
))
-
1
))
==
(
1
<<
(
in_mant
-
out_mant
+
exponent_diff
-
1
));
(
1ull
<<
(
SrcT_mant
-
DstT_mant
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
shift right as shift right could rip off some residual part and make something not midpoint look
done before we shift right as shift right could rip off some residual part and
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
make something not midpoint look like midpoint. For example, the fp16 number
midpoint, but after shift right by 4 bits, it would look like midpoint. */
0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
by 4 bits, it would look like midpoint.
*/
if
(
exponent_diff
>
0
)
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
in_mant
);
bool
implicit_one
=
mantissa
&
(
1ull
<<
SrcT_mant
);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
// if there is no implicit 1, it means the f8 is denormal and need to adjust
out_exponent
=
// to denorm exponent
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
out_bias
-
(
implicit_one
?
0
:
1
);
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
// Now we have the exponent and mantissa adjusted
unsigned
long
long
drop_mask
=
(
1ull
<<
(
SrcT_mant
-
DstT_mant
))
-
1
;
bool
odd
=
bool
odd
=
mantissa
&
mantissa
&
(
1ull
<<
(
SrcT_mant
-
(
1
<<
(
in_mant
-
out_mant
));
// if the least significant bit that is not truncated is 1
DstT_mant
));
// if the least significant bit that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1ull
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
// Now we deal with overflow
if
(
out
_exponent
==
0
)
if
(
f8
_exponent
==
0
)
{
{
if
((
1
<<
in
_mant
)
&
mantissa
)
if
((
1
ull
<<
SrcT
_mant
)
&
mantissa
)
{
{
out_exponent
=
1
;
// denormal overflow to become normal, promote exponent
f8_exponent
=
1
;
// denormal overflow to become normal, promote exponent
// No need to make 1 implicit now as it will be addressed later
}
}
}
}
else
else
{
{
if
((
1
<<
(
in
_mant
+
1
))
&
mantissa
)
if
((
1
ull
<<
(
SrcT
_mant
+
1
))
&
mantissa
)
{
{
mantissa
>>=
1
;
mantissa
>>=
1
;
out_exponent
++
;
f8_exponent
++
;
// No need to make 1 implicit now as it will be addressed later
}
}
}
}
mantissa
>>=
(
in
_mant
-
out
_mant
);
mantissa
>>=
(
SrcT
_mant
-
DstT
_mant
);
if
(
out_exponent
>
max_exp
)
// above range: quantize to maximum possible float of the same sign
const
int
max_exp
=
(
1
<<
DstT_exp
)
-
1
;
if
(
f8_exponent
>
max_exp
)
{
{
if
(
clip
)
if
constexpr
(
clip
)
{
{
mantissa
=
(
1
<<
out
_mant
)
-
1
;
mantissa
=
(
1
<<
DstT
_mant
)
-
1
;
out
_exponent
=
max_exp
;
f8
_exponent
=
max_exp
;
}
}
else
else
{
{
return
__builtin_bit_cast
(
Y
,
static_cast
<
uint8_t
>
(
signed_inf
))
;
return
signed_inf
;
}
}
}
}
// check if x is 0.0 or -0.0
if
(
f8_exponent
==
0
&&
mantissa
==
0
)
if
(
out_exponent
==
0
&&
mantissa
==
0
)
return
is_fnuz
?
0
:
(
sign
<<
7
);
return
__builtin_bit_cast
(
mantissa
&=
(
1
<<
DstT_mant
)
-
1
;
Y
,
static_cast
<
uint8_t
>
(
negative_zero_nan
?
0
:
(
sign
<<
(
out_exp
+
out_mant
))));
return
(
sign
<<
7
)
|
(
f8_exponent
<<
DstT_mant
)
|
mantissa
;
mantissa
&=
(
1
<<
out_mant
)
-
1
;
return
__builtin_bit_cast
(
Y
,
static_cast
<
uint8_t
>
((
sign
<<
(
out_exp
+
out_mant
))
|
(
out_exponent
<<
out_mant
)
|
mantissa
));
}
}
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
template
<
typename
SrcT
,
typename
DstT
,
bool
clip
=
true
>
CK_TILE_HOST_DEVICE
Y
run_cast_from_f8
(
X
x
)
CK_TILE_HOST_DEVICE
DstT
run_cast_from_f8
(
SrcT
x
)
{
{
// fp8/bf8 exponent/mantissa layout
static_assert
(
std
::
is_same
<
SrcT
,
fp8_t
>::
value
||
std
::
is_same
<
SrcT
,
bf8_t
>::
value
,
constexpr
int
in_exp
=
numeric_traits
<
X
>::
exp
;
"SrcT type must be fp8 or bf8."
);
constexpr
int
in_mant
=
numeric_traits
<
X
>::
mant
;
constexpr
int
SrcT_exp
=
numeric_traits
<
SrcT
>::
exp
;
constexpr
int
SrcT_mant
=
numeric_traits
<
SrcT
>::
mant
;
// resulting type exponent/mantissa layout
constexpr
bool
is_fnuz
=
constexpr
int
out_exp
=
numeric_traits
<
Y
>::
exp
;
(
numeric_traits
<
SrcT
>::
f8_interpret
==
fp8_interpretation
::
E4M3_FNUZ
)
||
constexpr
int
out_mant
=
numeric_traits
<
Y
>::
mant
;
(
numeric_traits
<
SrcT
>::
f8_interpret
==
fp8_interpretation
::
E5M2_FNUZ
);
uint8_t
x_raw
=
__builtin_bit_cast
(
uint8_t
,
x
);
constexpr
bool
is_half
=
std
::
is_same
<
DstT
,
half_t
>::
value
;
// prepare the codes
constexpr
bool
is_float
=
std
::
is_same
<
DstT
,
float
>::
value
;
constexpr
uint8_t
nan_code
=
0x80
;
static_assert
(
is_half
||
is_float
,
"DstT type must be half_t or float."
);
Y
Inf
,
NegInf
,
NaN
,
Neg0
;
using
T_bitwise
=
typename
numeric_traits
<
Y
>::
bitwise_type
;
// destination type exponent/mantissa layout
constexpr
int
DstT_exp
=
numeric_traits
<
DstT
>::
exp
;
// exponent width of the destination type
constexpr
T_bitwise
Inf_bitwise
=
numeric_traits
<
Y
>::
Inf
;
constexpr
int
DstT_mant
=
numeric_traits
<
DstT
>::
mant
;
// mantissa width of the destination type
constexpr
T_bitwise
NegInf_bitwise
=
numeric_traits
<
Y
>::
NegInf
;
constexpr
T_bitwise
NaN_bitwise
=
numeric_traits
<
Y
>::
NaN
;
constexpr
DstT
fInf
=
bit_cast
<
DstT
>
(
numeric_traits
<
DstT
>::
Inf
);
constexpr
T_bitwise
Neg0_bitwise
=
numeric_traits
<
Y
>::
Neg0
;
constexpr
DstT
fNegInf
=
bit_cast
<
DstT
>
(
numeric_traits
<
DstT
>::
NegInf
);
constexpr
DstT
fNaN
=
bit_cast
<
DstT
>
(
numeric_traits
<
DstT
>::
NaN
);
Inf
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
Inf_bitwise
));
constexpr
DstT
fNeg0
=
bit_cast
<
DstT
>
(
numeric_traits
<
DstT
>::
Neg0
);
NegInf
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
NegInf_bitwise
));
NaN
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
NaN_bitwise
));
DstT
fmax
{
0
},
fmin
{
0
};
Neg0
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
Neg0_bitwise
));
// Max number in e5m2 57344
if
constexpr
(
is_half
)
// check if x is 0.0
{
if
(
x_raw
==
0
)
fmax
=
bit_cast
<
DstT
>
(
static_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
0x7B00
));
return
static_cast
<
Y
>
(
0
);
fmin
=
bit_cast
<
DstT
>
(
static_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
0xFB00
));
}
// unpack the input
else
if
constexpr
(
is_float
)
uint32_t
sign
=
x_raw
>>
(
in_exp
+
in_mant
);
{
uint32_t
mantissa
=
x_raw
&
((
1
<<
in_mant
)
-
1
);
fmax
=
bit_cast
<
DstT
>
(
static_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
0x47600000
));
int
exponent
=
(
x_raw
&
0x7F
)
>>
in_mant
;
fmin
=
bit_cast
<
DstT
>
(
static_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
0xC7600000
));
}
constexpr
int
exp_low_cutoff
=
if
(
x
==
0
)
(
1
<<
(
out_exp
-
1
))
-
(
1
<<
(
in_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
{
T_bitwise
retval
;
return
0
;
}
if
constexpr
(
negative_zero_nan
)
unsigned
long
long
sign
=
x
>>
7
;
unsigned
long
long
mantissa
=
x
&
((
1
<<
SrcT_mant
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
SrcT_mant
;
if
constexpr
(
is_fnuz
)
{
if
(
x
==
0x80
)
{
{
if
(
x_raw
==
nan_code
)
return
fNaN
;
return
NaN
;
}
}
}
else
else
{
{
if
(
x_raw
==
nan_code
)
if
(
x
==
0x80
)
return
Neg0
;
{
if
(
exponent
==
((
1
<<
in_exp
)
-
1
))
return
fNeg0
;
return
(
mantissa
==
0
)
?
(
sign
?
NegInf
:
Inf
)
:
NaN
;
}
if
constexpr
(
SrcT_exp
==
4
)
{
// e4m3
if
((
x
&
0x7F
)
==
0x7F
)
{
return
fNaN
;
}
}
else
if
((
x
&
0x7C
)
==
0x7C
)
{
// e5m2
if
((
x
&
0x3
)
==
0
)
{
if
constexpr
(
clip
)
{
return
sign
?
fmin
:
fmax
;
}
return
sign
?
fNegInf
:
fInf
;
}
return
fNaN
;
}
}
}
typename
numeric_traits
<
DstT
>::
bitwise_type
retval
;
if
((
numeric_traits
<
Y
>::
mant
==
10
)
&&
(
numeric_traits
<
X
>::
mant
==
2
)
&&
!
negative_zero_nan
)
if
constexpr
(
SrcT_exp
==
5
&&
is_half
&&
!
is_fnuz
)
{
{
retval
=
x_raw
;
retval
=
x
<<
8
;
retval
<<=
8
;
return
bit_cast
<
DstT
>
(
retval
);
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
}
const
int
exp_low_cutoff
=
(
1
<<
(
DstT_exp
-
1
))
-
(
1
<<
(
SrcT_exp
-
1
))
+
1
-
(
is_fnuz
?
1
:
0
);
// subnormal input
// subnormal input
if
(
exponent
==
0
)
if
(
exponent
==
0
)
{
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
SrcT_mant
);
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
in_mant
);
mantissa
<<=
sh
;
mantissa
<<=
sh
;
exponent
+=
1
-
sh
;
exponent
+=
1
-
sh
;
mantissa
&=
((
1
<<
in
_mant
)
-
1
);
mantissa
&=
((
1
ull
<<
SrcT
_mant
)
-
1
);
}
}
exponent
+=
exp_low_cutoff
-
1
;
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
out
_mant
-
in
_mant
;
mantissa
<<=
DstT
_mant
-
SrcT
_mant
;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
// subnormal output (occurs when DstT is half_t, we=5, is_fnuz=true)
if
(
exponent
<=
0
)
if
(
exponent
<=
0
)
{
{
mantissa
|=
1
<<
out
_mant
;
mantissa
|=
1
<<
DstT
_mant
;
mantissa
>>=
1
-
exponent
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
exponent
=
0
;
}
}
retval
=
(
sign
<<
(
out_exp
+
out_mant
))
|
(
exponent
<<
out_mant
)
|
mantissa
;
retval
=
(
sign
<<
(
DstT_exp
+
DstT_mant
))
|
(
exponent
<<
DstT_mant
)
|
mantissa
;
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
CK_TILE_HOST_DEVICE
Y
cast_to_f8
(
X
x
,
uint32_t
rng
)
{
// check datatypes
constexpr
bool
is_half
=
std
::
is_same
<
X
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
X
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted."
);
return
run
_cast
_to_f8
<
X
,
Y
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
return
bit
_cast
<
DstT
>
(
retval
);
}
}
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
template
<
typename
X
,
typename
Y
,
bool
clip
,
bool
stoch
>
CK_TILE_HOST_DEVICE
Y
cast_
from
_f8
(
X
x
)
CK_TILE_HOST_DEVICE
Y
cast_
to
_f8
(
X
x
,
uint32_t
rng
)
{
{
// check datatype
return
bit_cast
<
Y
>
(
run_cast_to_f8
<
X
,
Y
,
clip
,
stoch
>
(
x
,
rng
));
constexpr
bool
is_half
=
std
::
is_same
<
Y
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
Y
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported."
);
return
run_cast_from_f8
<
X
,
Y
,
negative_zero_nan
>
(
x
);
}
}
}
// namespace impl
CK_TILE_HOST_DEVICE
fp8_raw_t
float_to_fp8_sr_raw
(
float
x
)
#if CK_TILE_FP8_CVT_DEVICE
/**
* @brief Cast float to fp8/bf8 using device conversion instructions
*/
template
<
fp8_interpretation
interpret
,
bool
saturate
,
bool
stochastic_rounding
=
false
>
CK_TILE_DEVICE
uint8_t
cast_to_f8_from_f32
(
float
v
,
unsigned
int
rng
=
0
)
{
{
constexpr
int
seed
=
42
;
uint8_t
i8data
;
uint32_t
rng
=
prand_generator_t
<
float
,
seed
>
{}(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
float
max_fp8
=
240.0
f
;
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
union
union
{
{
float
fval
;
float
fval
;
u
int32_
t
i32val
;
u
nsigned
in
t
i32val
;
u
int8_t
i8val
[
4
];
// not endian independent
u
nsigned
char
i8val
[
4
];
//
NOTE:
not endian independent
}
val
;
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
unsigned
int
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
fval
=
v
;
if
constexpr
(
saturate
)
{
if
constexpr
(
interpret
==
fp8_interpretation
::
E4M3_FNUZ
)
{
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
}
}
else
if
constexpr
(
interpret
==
fp8_interpretation
::
E4M3_OCP
)
{
// OCP type
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
448.0
,
-
448.0
);
}
}
else
{
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
57344.0
,
-
57344.0
);
}
}
}
if
constexpr
(
stochastic_rounding
)
{
ival
=
(
interpret
==
fp8_interpretation
::
E4M3_FNUZ
)
||
(
interpret
==
fp8_interpretation
::
E4M3_OCP
)
?
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
)
:
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
// little endian
i8data
=
val
.
i8val
[
0
];
// little endian
#else
}
constexpr
bool
negative_zero_nan
=
true
;
else
constexpr
bool
clip
=
true
;
{
// RNE CVT
constexpr
fp8_rounding_mode
rm
=
fp8_rounding_mode
::
stochastic
;
ival
=
(
interpret
==
fp8_interpretation
::
E4M3_FNUZ
)
||
return
bit_cast
<
fp8_raw_t
>
(
impl
::
cast_to_f8
<
float
,
(
interpret
==
fp8_interpretation
::
E4M3_OCP
)
fp8_t
,
?
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
)
negative_zero_nan
,
:
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
clip
,
val
.
fval
,
(
rm
==
fp8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
ival
,
#endif
false
);
// false -> WORD0
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
}
return
i8data
;
}
}
#endif // CK_TILE_FP8_CVT_DEVICE
CK_TILE_HOST_DEVICE
bf8_raw_t
float_to_bf8_sr_raw
(
float
x
)
}
// namespace impl
/**
* @brief Converts a floating-point value to an 8-bit floating-point representation with stochastic
* rounding.
*
* This function converts a floating-point value (float or half_t) to an 8-bit floating-point
* representation of type fp8_t or bf8_t. The conversion process may
* involve clipping and uses a pseudo-random number generator for the stochastic rounding.
*
* @tparam DstT The destination type (fp8_t or bf8_t).
* @tparam SrcT The source type (float or half_t) to be converted.
* @param x The floating-point value to be converted.
* @return The 8-bit floating-point representation of the input value.
*/
template
<
typename
SrcT
,
typename
DstT
>
CK_TILE_HOST_DEVICE
typename
numeric_traits
<
DstT
>::
bitwise_type
float_to_fp8_sr_raw
(
SrcT
x
)
{
{
constexpr
bool
clip
=
true
;
constexpr
int
seed
=
42
;
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator_t
<
float
,
seed
>
{}(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator_t
<
SrcT
,
seed
>
{}(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
#if CK_TILE_FP8_CVT_DEVICE
union
return
impl
::
cast_to_f8_from_f32
<
numeric_traits
<
DstT
>::
f8_interpret
,
clip
,
true
>
(
x
,
rng
);
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
// little endian
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
return
bit_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
constexpr
bool
clip
=
true
;
impl
::
cast_to_f8
<
SrcT
,
DstT
,
clip
,
true
>
(
x
,
rng
));
constexpr
fp8_rounding_mode
rm
=
fp8_rounding_mode
::
stochastic
;
return
bit_cast
<
bf8_raw_t
>
(
impl
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
fp8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
#endif
#endif
}
}
CK_TILE_HOST_DEVICE
fp8_raw_t
float_to_fp8_rtn_raw
(
float
x
)
/**
* @brief Converts a floating-point value to an 8-bit floating-point representation with rounding to
* nearest even.
*
* This function converts a floating-point value (float or half_t) to an 8-bit floating-point
* representation of type fp8_t or bf8_t. The conversion process may involve clipping.
*
* @tparam DstT The destination type (fp8_t or bf8_t).
* @tparam SrcT The source type (float or half_t) to be converted.
* @param x The floating-point value to be converted.
* @return The 8-bit floating-point representation of the input value.
*/
template
<
typename
SrcT
,
typename
DstT
>
CK_TILE_HOST_DEVICE
typename
numeric_traits
<
DstT
>::
bitwise_type
float_to_fp8_rtn_raw
(
SrcT
x
)
{
{
#if defined(__gfx94__)
float
max_fp8
=
240.0
f
;
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
fp8_rounding_mode
rm
=
fp8_rounding_mode
::
standard
;
#if CK_TILE_FP8_CVT_DEVICE
constexpr
uint32_t
rng
=
0
;
return
impl
::
cast_to_f8_from_f32
<
numeric_traits
<
DstT
>::
f8_interpret
,
clip
,
false
>
(
x
,
0
);
return
bit_cast
<
fp8_raw_t
>
(
impl
::
cast_to_f8
<
float
,
fp8_t
,
negative_zero_nan
,
clip
,
(
rm
==
fp8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
#endif
}
CK_TILE_HOST_DEVICE
bf8_raw_t
float_to_bf8_rtn_raw
(
float
x
)
{
#if defined(__gfx94__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
return
bit_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
constexpr
bool
clip
=
true
;
impl
::
cast_to_f8
<
SrcT
,
DstT
,
clip
,
false
>
(
x
,
0
));
constexpr
fp8_rounding_mode
rm
=
fp8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
bit_cast
<
bf8_raw_t
>
(
impl
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
fp8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
#endif
#endif
}
}
// clang-format off
template
<
fp8_rounding_mode
rounding
>
template
<
fp8_rounding_mode
rounding
>
CK_TILE_HOST_DEVICE
fp8_raw_t
float_to_fp8_raw
(
float
x
,
constant
<
rounding
>
)
CK_TILE_HOST_DEVICE
fp8_raw_t
float_to_fp8_raw
(
float
x
,
constant
<
rounding
>
)
{
{
if
constexpr
(
rounding
==
fp8_rounding_mode
::
standard
)
return
float_to_fp8_rtn_raw
(
x
);
if
constexpr
(
rounding
==
fp8_rounding_mode
::
standard
)
else
if
constexpr
(
rounding
==
fp8_rounding_mode
::
stochastic
)
return
float_to_fp8_sr_raw
(
x
);
{
else
return
fp8_raw_t
{
0
};
return
float_to_fp8_rtn_raw
<
float
,
fp8_t
>
(
x
);
}
else
if
constexpr
(
rounding
==
fp8_rounding_mode
::
stochastic
)
{
return
float_to_fp8_sr_raw
<
float
,
fp8_t
>
(
x
);
}
else
{
return
fp8_raw_t
{
0
};
}
}
}
template
<
fp8_rounding_mode
rounding
>
template
<
fp8_rounding_mode
rounding
>
CK_TILE_HOST_DEVICE
bf8_raw_t
float_to_bf8_raw
(
float
x
,
constant
<
rounding
>
)
CK_TILE_HOST_DEVICE
bf8_raw_t
float_to_bf8_raw
(
float
x
,
constant
<
rounding
>
)
{
{
if
constexpr
(
rounding
==
fp8_rounding_mode
::
standard
)
return
float_to_bf8_rtn_raw
(
x
);
if
constexpr
(
rounding
==
fp8_rounding_mode
::
standard
)
else
if
constexpr
(
rounding
==
fp8_rounding_mode
::
stochastic
)
return
float_to_bf8_sr_raw
(
x
);
{
else
return
bf8_raw_t
{
0
};
return
float_to_fp8_rtn_raw
<
float
,
bf8_t
>
(
x
);
}
else
if
constexpr
(
rounding
==
fp8_rounding_mode
::
stochastic
)
{
return
float_to_fp8_sr_raw
<
float
,
bf8_t
>
(
x
);
}
else
{
return
bf8_raw_t
{
0
};
}
}
}
CK_TILE_HOST_DEVICE
float
fp8_to_float_raw
(
fp8_raw_t
x
)
CK_TILE_HOST_DEVICE
float
fp8_to_float_raw
(
fp8_raw_t
x
)
{
{
#if
defined(__gfx94__)
#if
CK_TILE_FP8_CVT_DEVICE
float
fval
;
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
x
);
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
x
);
fval
=
__builtin_amdgcn_cvt_f32_fp8
(
i32val
,
0
);
fval
=
__builtin_amdgcn_cvt_f32_fp8
(
i32val
,
0
);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return
fval
;
return
fval
;
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
return
impl
::
run_cast_from_f8
<
fp8_t
,
float
>
(
bit_cast
<
fp8_t
>
(
x
));
return
impl
::
cast_from_f8
<
fp8_t
,
float
,
negative_zero_nan
>
(
bit_cast
<
fp8_t
>
(
x
));
#endif
#endif
}
}
CK_TILE_HOST_DEVICE
float
bf8_to_float_raw
(
bf8_raw_t
x
)
CK_TILE_HOST_DEVICE
float
bf8_to_float_raw
(
bf8_raw_t
x
)
{
{
#if
defined(__gfx94__)
#if
CK_TILE_FP8_CVT_DEVICE
float
fval
;
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
x
);
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
x
);
fval
=
__builtin_amdgcn_cvt_f32_bf8
(
i32val
,
0
);
fval
=
__builtin_amdgcn_cvt_f32_bf8
(
i32val
,
0
);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return
fval
;
return
fval
;
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
return
impl
::
run_cast_from_f8
<
bf8_t
,
float
>
(
bit_cast
<
bf8_t
>
(
x
));
return
impl
::
cast_from_f8
<
bf8_t
,
float
,
negative_zero_nan
>
(
bit_cast
<
bf8_t
>
(
x
));
#endif
#endif
}
}
template
<
fp8_rounding_mode
rounding
=
static_cast
<
fp8_rounding_mode
>(
CK_TILE_FLOAT_TO_FP8_DEFAULT
)
>
template
<
fp8_rounding_mode
rounding
=
static_cast
<
fp8_rounding_mode
>(
CK_TILE_FLOAT_TO_FP8_DEFAULT
)
>
CK_TILE_HOST_DEVICE
fp8_t
float_to_fp8
(
float
x
,
constant
<
rounding
>
=
{})
CK_TILE_HOST_DEVICE
fp8_t
float_to_fp8
(
float
x
,
constant
<
rounding
>
=
{})
{
{
return
bit_cast
<
fp8_t
>
(
float_to_fp8_raw
(
x
,
constant
<
rounding
>
{}));
return
bit_cast
<
fp8_t
>
(
float_to_fp8_raw
(
x
,
constant
<
rounding
>
{}));
}
}
template
<
fp8_rounding_mode
rounding
=
static_cast
<
fp8_rounding_mode
>(
CK_TILE_FLOAT_TO_FP8_DEFAULT
)
>
template
<
fp8_rounding_mode
rounding
=
static_cast
<
fp8_rounding_mode
>(
CK_TILE_FLOAT_TO_FP8_DEFAULT
)
>
CK_TILE_HOST_DEVICE
bf8_t
float_to_bf8
(
float
x
,
constant
<
rounding
>
=
{})
CK_TILE_HOST_DEVICE
bf8_t
float_to_bf8
(
float
x
,
constant
<
rounding
>
=
{})
{
{
return
bit_cast
<
bf8_t
>
(
float_to_bf8_raw
(
x
,
constant
<
rounding
>
{}));
return
bit_cast
<
bf8_t
>
(
float_to_bf8_raw
(
x
,
constant
<
rounding
>
{}));
}
}
CK_TILE_HOST_DEVICE
float
fp8_to_float
(
fp8_t
x
)
CK_TILE_HOST_DEVICE
float
fp8_to_float
(
fp8_t
x
)
{
return
fp8_to_float_raw
(
bit_cast
<
fp8_raw_t
>
(
x
));
}
{
return
fp8_to_float_raw
(
bit_cast
<
fp8_raw_t
>
(
x
));
}
CK_TILE_HOST_DEVICE
float
bf8_to_float
(
bf8_t
x
)
CK_TILE_HOST_DEVICE
float
bf8_to_float
(
bf8_t
x
)
{
return
bf8_to_float_raw
(
bit_cast
<
bf8_raw_t
>
(
x
));
}
{
return
bf8_to_float_raw
(
bit_cast
<
bf8_raw_t
>
(
x
));
}
// clang-format on
template
<
class
T
>
struct
numeric
;
template
<
typename
T
>
struct
numeric_traits
;
#if CK_TILE_USE_OCP_FP8
template
<
>
template
<
>
struct
numeric
_traits
<
fp8_t
>
struct
numeric
<
fp8_t
>
{
{
static
constexpr
int
exp
=
4
;
// minimum finite value, or minimum positive normal value
static
constexpr
int
mant
=
3
;
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
min
()
#if defined(__gfx94__)
{
static
constexpr
int
bias
=
8
;
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x08
));
// 0b00001000 = 2^-6
#else
}
static
constexpr
int
bias
=
7
;
#endif
// minumum finite value
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
lowest
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0xfe
));
// 0b11111110 = -448
}
// maximum finite value
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
max
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x7e
));
// 0b01111110 = 448
}
// difference between 1.0 and next representable f8 value (1.125)
// returns fp8_t(0.125)
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
epsilon
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x20
));
// 0.125
}
// rounding error (0.0625)
// half of epsilon
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
round_error
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x18
));
// 0.0625
}
// quiet NaN
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
quiet_NaN
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x7F
));
// 0b01111111
}
// signaling NaN
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
signaling_NaN
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0xFF
));
// 0b11111111
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
denorm_min
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x01
));
}
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
zero
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0
));
}
};
};
template
<
>
template
<
>
struct
numeric
_traits
<
bf8_t
>
struct
numeric
<
bf8_t
>
{
{
static
constexpr
int
exp
=
5
;
// minimum finite value, or minimum positive normalized value for float
static
constexpr
int
mant
=
2
;
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
min
()
#if defined(__gfx94__)
{
static
constexpr
int
bias
=
16
;
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x04
));
// 0b00000100 = 2^-14
#else
}
static
constexpr
int
bias
=
15
;
// IEEE
#endif
};
template
<
class
T
>
// minumum finite value
struct
numeric
;
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
lowest
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0xfb
));
// 0b11111011 = -57344
}
// maximum finite value
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
max
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x7b
));
// 0b01111011 = 57344
}
// difference between 1.0 and next representable bf8 value (1.25)
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
epsilon
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x34
));
// 0.25
}
// rounding error (0.125)
// half of epsilon
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
round_error
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x30
));
// 0.125
}
// positive infinity value
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
infinity
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x7c
));
// 0b01111100
}
// quiet NaN
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
quiet_NaN
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x7F
));
// 0b01111111
}
// signaling NaN
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
signaling_NaN
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0xFF
));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
denorm_min
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x01
));
}
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
zero
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0
));
}
};
#else
template
<
>
template
<
>
struct
numeric
<
fp8_t
>
struct
numeric
<
fp8_t
>
{
{
...
@@ -811,6 +1054,7 @@ struct numeric<bf8_t>
...
@@ -811,6 +1054,7 @@ struct numeric<bf8_t>
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0
));
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0
));
}
}
};
};
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT
(
CK_TILE_HOST_DEVICE
,
fp8_t
)
CK_TILE_ARITHMETIC_USING_FLOAT
(
CK_TILE_HOST_DEVICE
,
fp8_t
)
...
@@ -818,19 +1062,26 @@ CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
...
@@ -818,19 +1062,26 @@ CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif
#endif
// math
// math
CK_TILE_HOST_DEVICE
template
<
typename
T
>
fp8_t
abs
(
const
fp8_t
&
x
)
CK_TILE_HOST_DEVICE
T
abs
(
const
T
&
x
)
{
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
bit_cast
<
fp8_raw_t
>
(
x
)
&
0x7f
));
static_assert
(
std
::
is_same_v
<
T
,
fp8_t
>
||
std
::
is_same_v
<
T
,
bf8_t
>
,
"Only fp8_t and bf8_t are supported"
);
return
bit_cast
<
T
>
(
static_cast
<
uint8_t
>
(
bit_cast
<
uint8_t
>
(
x
)
&
numeric_traits
<
T
>::
abs_mask
));
}
}
CK_TILE_HOST_DEVICE
CK_TILE_HOST_DEVICE
bool
isnan
(
const
fp8_t
&
x
)
bool
isnan
(
const
fp8_t
&
x
)
{
{
uint8_t
xx
=
bit_cast
<
fp8_raw_t
>
(
x
);
uint8_t
xx
=
bit_cast
<
fp8_raw_t
>
(
x
);
return
xx
==
0x80
;
// TODO: NANOO
}
#if CK_TILE_USE_OCP_FP8
return
(
xx
&
0x7f
)
==
0x7f
;
#else
return
xx
==
0x80
;
#endif
}
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE
CK_TILE_DEVICE
fp8_t
sqrt
(
fp8_t
x
)
{
return
static_cast
<
fp8_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
};
fp8_t
sqrt
(
fp8_t
x
)
{
return
static_cast
<
fp8_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
};
...
@@ -842,20 +1093,21 @@ fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); }
...
@@ -842,20 +1093,21 @@ fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE
CK_TILE_DEVICE
fp8_t
log
(
fp8_t
x
)
{
return
static_cast
<
fp8_t
>
(
__logf
(
static_cast
<
float
>
(
x
)));
};
fp8_t
log
(
fp8_t
x
)
{
return
static_cast
<
fp8_t
>
(
__logf
(
static_cast
<
float
>
(
x
)));
};
#endif
CK_TILE_HOST_DEVICE
bf8_t
abs
(
const
bf8_t
&
x
)
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
fp8_raw_t
>
(
bit_cast
<
bf8_raw_t
>
(
x
)
&
0x7f
));
}
CK_TILE_HOST_DEVICE
CK_TILE_HOST_DEVICE
bool
isnan
(
const
bf8_t
&
x
)
bool
isnan
(
const
bf8_t
&
x
)
{
{
uint8_t
xx
=
bit_cast
<
bf8_raw_t
>
(
x
);
uint8_t
xx
=
bit_cast
<
bf8_raw_t
>
(
x
);
return
xx
==
0x80
;
// TODO: NANOO
#if CK_TILE_USE_OCP_FP8
return
(
xx
&
0x7f
)
>
0x7c
;
#else
return
xx
==
0x80
;
#endif
}
}
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE
CK_TILE_DEVICE
bf8_t
sqrt
(
bf8_t
x
)
{
return
static_cast
<
bf8_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
};
bf8_t
sqrt
(
bf8_t
x
)
{
return
static_cast
<
bf8_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
};
...
@@ -867,5 +1119,6 @@ bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); }
...
@@ -867,5 +1119,6 @@ bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE
CK_TILE_DEVICE
bf8_t
log
(
bf8_t
x
)
{
return
static_cast
<
bf8_t
>
(
__logf
(
static_cast
<
float
>
(
x
)));
};
bf8_t
log
(
bf8_t
x
)
{
return
static_cast
<
bf8_t
>
(
__logf
(
static_cast
<
float
>
(
x
)));
};
#endif
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/numeric/half.hpp
View file @
35aebe59
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...
@@ -236,10 +236,11 @@ struct numeric_traits<half_t>
...
@@ -236,10 +236,11 @@ struct numeric_traits<half_t>
static
constexpr
uint16_t
head_mask
=
0xFC00
;
static
constexpr
uint16_t
head_mask
=
0xFC00
;
static
constexpr
uint16_t
mant_mask
=
0x3FF
;
static
constexpr
uint16_t
mant_mask
=
0x3FF
;
static
constexpr
uint16_t
exp_mask
=
0x1F
;
static
constexpr
uint16_t
exp_mask
=
0x1F
;
static
constexpr
uint32_t
Inf
=
0x7C00
;
static
constexpr
uint16_t
abs_mask
=
0x7FFF
;
static
constexpr
uint32_t
NegInf
=
0xFC00
;
static
constexpr
uint16_t
Inf
=
0x7C00
;
static
constexpr
uint32_t
NaN
=
0x7C01
;
static
constexpr
uint16_t
NegInf
=
0xFC00
;
static
constexpr
uint32_t
Neg0
=
0x8000
;
static
constexpr
uint16_t
NaN
=
0x7C01
;
static
constexpr
uint16_t
Neg0
=
0x8000
;
using
bitwise_type
=
uint16_t
;
using
bitwise_type
=
uint16_t
;
};
};
...
...
include/ck_tile/core/numeric/numeric.hpp
View file @
35aebe59
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -89,6 +89,7 @@ struct numeric_traits<float>
...
@@ -89,6 +89,7 @@ struct numeric_traits<float>
static
constexpr
uint32_t
head_mask
=
0xFF800000
;
static
constexpr
uint32_t
head_mask
=
0xFF800000
;
static
constexpr
uint32_t
mant_mask
=
0x7FFFFF
;
static
constexpr
uint32_t
mant_mask
=
0x7FFFFF
;
static
constexpr
uint32_t
exp_mask
=
0xFF
;
static
constexpr
uint32_t
exp_mask
=
0xFF
;
static
constexpr
uint32_t
abs_mask
=
0x7FFFFFFF
;
static
constexpr
uint32_t
Inf
=
0x7F800000
;
static
constexpr
uint32_t
Inf
=
0x7F800000
;
static
constexpr
uint32_t
NegInf
=
0xFF800000
;
static
constexpr
uint32_t
NegInf
=
0xFF800000
;
static
constexpr
uint32_t
NaN
=
0x7F800001
;
static
constexpr
uint32_t
NaN
=
0x7F800001
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment