Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f90f5da6
Commit
f90f5da6
authored
Oct 29, 2024
by
Rostyslav Geyyer
Browse files
Add scale type and mxfp conversions
parent
47f161dc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
618 additions
and
9 deletions
+618
-9
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+86
-9
include/ck/utility/e8m0_utils.hpp
include/ck/utility/e8m0_utils.hpp
+11
-0
include/ck/utility/mxf4_utils.hpp
include/ck/utility/mxf4_utils.hpp
+138
-0
include/ck/utility/mxfp_utils.hpp
include/ck/utility/mxfp_utils.hpp
+383
-0
No files found.
include/ck/utility/data_type.hpp
View file @
f90f5da6
...
...
@@ -7,12 +7,13 @@
namespace
ck
{
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
f4_t
=
unsigned
_BitInt
(
4
);
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
f4_t
=
unsigned
_BitInt
(
4
);
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
e8m0_scale_t
=
uint8_t
;
// vector_type
template
<
typename
T
,
index_t
N
>
...
...
@@ -1162,11 +1163,51 @@ struct NumericLimits<f4_t>
static
constexpr
uint8_t
binary_min_subnorm
=
0x1
;
// 0b0001
static
constexpr
uint8_t
binary_max_subnorm
=
0x1
;
// 0b0001
static
constexpr
float
data_max_normal_number
=
6
;
static
constexpr
float
data_min_subnormal_number
=
0.5
;
__host__
__device__
static
constexpr
f4_t
Min
()
{
return
f4_t
(
binary_min_normal
);
}
__host__
__device__
static
constexpr
f4_t
Max
()
{
return
f4_t
(
binary_max_normal
);
}
__host__
__device__
static
constexpr
f4_t
Lowest
()
{
return
f4_t
(
binary_lowest_normal
);
}
__host__
__device__
static
constexpr
f4_t
MinSubnorm
()
{
return
f4_t
(
binary_min_subnorm
);
}
__host__
__device__
static
constexpr
f4_t
MaxSubnorm
()
{
return
f4_t
(
binary_max_subnorm
);
}
__host__
__device__
static
constexpr
float
DataMaxNorm
()
{
return
data_max_normal_number
;
}
__host__
__device__
static
constexpr
float
DataMinSubnorm
()
{
return
data_min_subnormal_number
;
}
};
template
<
>
struct
NumericLimits
<
e8m0_scale_t
>
{
static
constexpr
e8m0_scale_t
binary_min
=
0x00
;
// 0b00000000
static
constexpr
e8m0_scale_t
binary_max
=
0xFE
;
// 0b11111110
static
constexpr
e8m0_scale_t
binary_qnan
=
0xFF
;
// 0b11111111
static
constexpr
e8m0_scale_t
binary_1
=
0x7F
;
// 0b01111111
static
constexpr
e8m0_scale_t
binary_2
=
0x80
;
// 0b10000000
static
constexpr
e8m0_scale_t
binary_3
=
0x82
;
// 0b10000010
static
constexpr
e8m0_scale_t
binary_135
=
0x87
;
// 0b10000111
static
constexpr
e8m0_scale_t
binary_142
=
0x8E
;
// 0b10001110
__host__
__device__
static
constexpr
e8m0_scale_t
Min
()
{
return
e8m0_scale_t
(
binary_min
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Max
()
{
return
e8m0_scale_t
(
binary_max
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
QuietNaN
()
{
return
e8m0_scale_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_1
()
{
return
e8m0_scale_t
(
binary_1
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_2
()
{
return
e8m0_scale_t
(
binary_2
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_3
()
{
return
e8m0_scale_t
(
binary_3
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_135
()
{
return
e8m0_scale_t
(
binary_135
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_142
()
{
return
e8m0_scale_t
(
binary_142
);
}
};
template
<
typename
T
>
...
...
@@ -1229,8 +1270,44 @@ struct NumericUtils<bf8_t>
template
<
>
struct
NumericUtils
<
f4_t
>
{
static
constexpr
int
exp
=
2
;
static
constexpr
int
mant
=
1
;
static
constexpr
int
bias
=
1
;
static
constexpr
int
exp
=
2
;
static
constexpr
int
mant
=
1
;
static
constexpr
int
bias
=
1
;
static
constexpr
uint32_t
sr_shift
=
10
;
static
constexpr
int
unbiased_exp_min
=
0
;
static
constexpr
int
unbiased_exp_max
=
2
;
static
constexpr
int
biased_exp_min
=
1
;
static
constexpr
int
biased_exp_max
=
3
;
static
constexpr
uint8_t
positive_zero_mask
=
0b0000
;
static
constexpr
uint8_t
negative_zero_mask
=
0b1000
;
static
constexpr
uint8_t
one_mask
=
0b0010
;
static
constexpr
uint8_t
set_sign_mask
=
0b0111
;
static
constexpr
uint8_t
data_max_positive_normal_mask
=
0b0111
;
static
constexpr
uint8_t
data_max_negative_normal_mask
=
0b1111
;
static
constexpr
uint8_t
data_max_positive_subnormal_mask
=
0b0001
;
static
constexpr
uint8_t
data_max_negative_subnormal_mask
=
0b1001
;
using
bitwise_type
=
uint8_t
;
};
template
<
>
struct
NumericUtils
<
e8m0_scale_t
>
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
0
;
static
constexpr
int
bias
=
127
;
static
constexpr
int
unbiased_exp_min
=
-
127
;
static
constexpr
int
unbiased_exp_max
=
127
;
static
constexpr
int
biased_exp_min
=
0
;
static
constexpr
int
biased_exp_max
=
254
;
using
bitwise_type
=
uint8_t
;
};
}
// namespace ck
include/ck/utility/e8m0_utils.hpp
0 → 100644
View file @
f90f5da6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace
ck
::
utils
{
}
// namespace ck::utils
include/ck/utility/mxf4_utils.hpp
0 → 100644
View file @
f90f5da6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace
ck
{
__host__
inline
int
clz
(
uint32_t
x
)
{
return
__builtin_clz
(
x
);
}
__device__
inline
int
clz
(
uint32_t
x
)
{
return
__clz
(
x
);
}
}
// namespace ck
namespace
ck
::
utils
{
// template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
// __host__ __device__ Y cast_to_f8(X x, uint32_t rng)
// {
// // check datatypes
// constexpr bool is_half = std::is_same<X, half_t>::value;
// constexpr bool is_float = std::is_same<X, float>::value;
// static_assert(is_half || is_float, "Only half and float can be casted.");
// return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
// }
// template <typename X, typename Y, bool negative_zero_nan>
// __host__ __device__ Y cast_from_f8(X x)
// {
// // check datatype
// constexpr bool is_half = std::is_same<Y, half_t>::value;
// constexpr bool is_float = std::is_same<Y, float>::value;
// static_assert(is_half || is_float, "only half and float are supported.");
// return run_cast_from_f8<X, Y, negative_zero_nan>(x);
// }
template
<
>
__host__
__device__
inline
bool
is_nan
<
f4_t
>
(
e8m0_scale_t
const
scale
,
f4_t
const
dataBytes
[[
maybe_unused
]])
{
// no need to check for data as it does not have NaN representation
return
scale
==
NumericLimits
<
e8m0_scale_t
>::
QuietNaN
();
}
// no infinity representation in ocp_e2m1_mxfp4 will always return false
template
<
>
__host__
__device__
inline
bool
is_inf
<
f4_t
>
(
e8m0_scale_t
const
scale
[[
maybe_unused
]],
f4_t
const
data
[[
maybe_unused
]])
{
// no inf representation for ocp_e2m1_mxfp4
return
false
;
}
template
<
>
__host__
__device__
inline
bool
is_zero
<
f4_t
>
(
e8m0_scale_t
const
scale
,
f4_t
const
data
)
{
if
(
is_nan
<
f4_t
>
(
scale
,
data
))
return
false
;
// no need to check for scale as it does not have a 0 representation
f4_t
data
=
(
data
&
0b00001111
)
&
NumericUtils
<
e8m0_scale_t
>::
set_sign_mask
;
return
data
==
0b0
;
}
template
<
>
__host__
__device__
inline
float
to_float
<
f4_t
>
(
e8m0_scale_t
const
scale
,
f4_t
const
data
)
{
if
(
is_nan
<
f4_t
>
(
scale
,
data
))
return
std
::
numeric_limits
<
float
>::
quiet_NaN
();
if
(
is_zero
<
f4_t
>
(
scale
,
data
))
return
0.0
f
;
uint8_t
data
=
data
&
0b00001111
;
int
scale_exp
=
get_exponent_value
<
e8m0_scale_t
>
(
scale
);
return
convert_to_float
<
f4_t
>
(
data
,
scale_exp
);
}
template
<
>
__host__
__device__
inline
f4_t
sat_convert_to_type
<
f4_t
>
(
float
value
)
{
cvt
t
;
t
.
value_float
=
value
;
uint32_t
sign
=
t
.
value_bitwise
>>
31
;
if
(
std
::
isnan
(
value
))
{
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
}
if
(
std
::
abs
(
value
)
>
NumericLimits
<
f4_t
>::
Max
())
// covers inf case as well
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
f4_t
res
=
convert_to_type
<
f4_t
>
(
value
);
if
(
std
::
abs
(
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
res
))
<
NumericUtils
<
f4_t
>::
DataMinSubnorm
())
return
value
<
0
?
NumericUtils
<
f4_t
>::
negative_zero_mask
:
NumericUtils
<
f4_t
>::
positive_zero_mask
;
return
res
;
}
template
<
>
__host__
__device__
inline
f4_t
sat_convert_to_type_sr
<
f4_t
>
(
float
value
,
uint32_t
seed
)
{
cvt
t
;
t
.
value_float
=
value
;
uint32_t
sign
=
t
.
value_bitwise
>>
31
;
if
(
std
::
isnan
(
value
))
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
if
(
std
::
abs
(
value
)
>
NumericLimits
<
f4_t
>::
Max
())
// covers inf case as well
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
f4_t
res
=
convert_to_type_sr
<
f4_t
>
(
value
,
seed
);
if
(
std
::
abs
(
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
res
))
<
NumericUtils
<
f4_t
>::
DataMinSubnorm
())
return
value
<
0
?
NumericUtils
<
f4_t
>::
negative_zero_mask
:
NumericUtils
<
f4_t
>::
positive_zero_mask
;
return
res
;
}
}
// namespace ck::utils
include/ck/utility/mxfp_utils.hpp
0 → 100644
View file @
f90f5da6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
::
utils
{
union
cvt
{
float
value_float
;
uint32_t
value_bitwise
;
};
template
<
typename
DTYPE
>
inline
bool
getDataHasInf
()
{
return
DTYPE
::
dataInfo
.
hasInf
;
}
template
<
typename
T
>
__host__
__device__
inline
bool
is_zero
(
e8m0_scale_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
inline
bool
is_nan
(
e8m0_scale_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
inline
bool
is_inf
(
e8m0_scale_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
inline
int
get_exponent_value
(
T
x
)
{
x
>>=
NumericUtils
<
T
>::
mant
;
x
&=
((
1
<<
NumericUtils
<
T
>::
exp
)
-
1
);
return
static_cast
<
int
>
(
x
);
}
template
<
typename
T
>
__host__
__device__
inline
bool
is_subnormal
(
T
x
)
{
return
get_exponent_value
<
T
>
(
x
)
==
0
;
}
template
<
typename
T
>
__host__
__device__
inline
double
get_mantissa_value
(
T
x
)
{
double
mantissa
=
is_subnormal
<
T
>
(
x
)
?
0.0
f
:
1.0
f
;
for
(
uint
i
=
0
;
i
<
NumericUtils
<
T
>::
mant
;
i
++
)
{
mantissa
+=
std
::
pow
(
2
,
-
int32_t
((
NumericUtils
<
T
>::
mant
-
i
)))
*
(
x
&
0b1
);
x
>>=
1
;
}
return
mantissa
;
}
template
<
typename
T
>
__host__
__device__
float
convert_to_float
(
T
data
,
int
scale_exp
)
{
float
d_sign
=
std
::
pow
(
-
1
,
static_cast
<
float
>
(
data
>>
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
t
>::
mant
)));
float
d_exp
;
if
(
is_subnormal
<
T
>
(
data
))
d_exp
=
std
::
pow
(
2
,
1
-
static_cast
<
int
>
(
NumericUtils
<
T
>::
bias
));
else
d_exp
=
std
::
pow
(
2
,
get_exponent_value
<
T
>
(
data
)
-
static_cast
<
int
>
(
NumericUtils
<
T
>::
bias
));
float
d_mant
=
get_mantissa_value
<
T
>
(
data
);
float
data_value
=
d_sign
*
d_exp
*
d_mant
;
float
scale_value
=
std
::
pow
(
2
,
static_cast
<
float
>
((
scale_exp
-
static_cast
<
int
>
(
NumericUtils
<
e8m0_scale_t
>::
bias
))));
return
data_value
*
scale_value
;
}
template
<
typename
T
>
__host__
__device__
inline
float
to_float
(
e8m0_scale_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
T
sat_convert_to_type
(
float
value
);
template
<
typename
T
>
__host__
__device__
T
sat_convert_to_type_sr
(
float
value
,
uint32_t
seed
);
template
<
typename
T
>
inline
T
convert_to_type
(
float
value
)
{
using
bitwise_type
=
NumericUtils
<
T
>::
bitwise_type
;
if
(
std
::
abs
(
value
)
>
NumericLimits
<
T
>::
Max
())
{
float
max_value
=
NumericLimits
<
T
>::
Max
();
cvt
t
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
max_value
;
uint32_t
max_bitwise
=
t
.
value_bitwise
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
value
;
bitwise_type
sign
=
t
.
value_bitwise
>>
(
NumericUtils
<
float
>::
exp
+
NumericUtils
<
float
>::
mant
);
bitwise_type
exp
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
)
-
(
NumericUtils
<
float
>::
bias
-
NumericUtils
<
T
>::
bias
);
bitwise_type
mantissa
=
max_bitwise
>>
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
uint32_t
mant_prev
=
max_bitwise
>>
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
mant_prev
&=
((
1
<<
NumericUtils
<
T
>::
mant
)
-
1
);
mant_prev
--
;
mant_prev
<<=
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
uint32_t
prev_bit
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
<<
NumericUtils
<
float
>::
mant
)
|
mant_prev
;
t
.
value_bitwise
=
prev_bit
;
float
prev_val
=
t
.
value_float
;
float
diff
=
max_value
-
prev_val
;
float
actual_max
=
max_value
+
(
diff
/
2
);
if
(
std
::
abs
(
value
)
<
actual_max
)
{
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
|
(
exp
<<
NumericUtils
<
T
>::
mant
)
|
mantissa
;
}
else
{
if
(
!
get_data_has_inf
<
T
>
())
{
return
(
1
<<
(
NumericUtils
<
T
>::
mant
+
NumericUtils
<
T
>::
exp
))
-
1
;
}
else
{
exp
++
;
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
|
(
exp
<<
NumericUtils
<
T
>::
mant
);
}
}
}
const
int
mfmt
=
NumericUtils
<
float
>::
mant
;
uint32_t
x
;
// x = reinterpret_cast<uint32_t&>(value);
x
=
bit_cast
<
uint32_t
>
(
value
);
uint32_t
head
,
mantissa
;
int32_t
exponent
,
bias
;
uint32_t
sign
;
head
=
x
&
NumericUtils
<
float
>::
head_mask
;
mantissa
=
x
&
NumericUtils
<
float
>::
mant_mask
;
exponent
=
(
head
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
;
sign
=
head
>>
(
NumericUtils
<
float
>::
mant
+
NumericUtils
<
float
>::
exp
);
bias
=
NumericUtils
<
float
>::
bias
;
if
(
x
==
0
)
{
return
0b0
;
}
const
int
mini_bias
=
NumericUtils
<
T
>::
bias
;
const
int
mini_denormal_act_exponent
=
1
-
mini_bias
;
int
act_exponent
,
out_exponent
,
exponent_diff
;
bool
is_subnorm
=
false
;
if
(
exponent
==
0
)
{
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
mini_denormal_act_exponent
-
act_exponent
;
is_subnorm
=
true
;
}
else
{
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
mini_denormal_act_exponent
)
{
exponent_diff
=
mini_denormal_act_exponent
-
act_exponent
;
is_subnorm
=
true
;
}
else
{
exponent_diff
=
0
;
}
mantissa
+=
(
1UL
<<
mfmt
);
}
auto
shift_amount
=
(
mfmt
-
NumericUtils
<
T
>::
mant
+
exponent_diff
);
shift_amount
=
(
shift_amount
>=
64
)
?
63
:
shift_amount
;
bool
midpoint
=
(
mantissa
&
((
1UL
<<
shift_amount
)
-
1
))
==
(
1UL
<<
(
shift_amount
-
1
));
float
min_subnorm
=
NumericLimits
<
T
>
DataMinSubnorm
()
*
(
sign
?
-
1
:
1
);
if
(
isSubNorm
&&
std
::
abs
(
value
)
<
std
::
abs
(
min_subnorm
))
{
// closer to 0
if
(
std
::
abs
(
value
)
<=
std
::
abs
(
min_subnorm
-
value
))
return
0
;
else
return
1
|
(
sign
<<
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
));
}
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
out_exponent
=
(
act_exponent
+
exponent_diff
)
+
mini_bias
-
(
implicit_one
?
0
:
1
);
uint32_t
drop_mask
=
(
1UL
<<
(
mfmt
-
NumericUtils
<
T
>::
mant
))
-
1
;
bool
odd
=
mantissa
&
(
1UL
<<
(
mfmt
-
NumericUtils
<
T
>::
mant
));
mantissa
+=
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
)
&
drop_mask
;
if
(
out_exponent
==
0
)
{
if
((
1UL
<<
mfmt
)
&
mantissa
)
{
out_exponent
=
1
;
}
}
else
{
if
((
1UL
<<
(
mfmt
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
out_exponent
++
;
}
}
mantissa
>>=
(
mfmt
-
NumericUtils
<
T
>::
mant
);
if
(
out_exponent
==
0
&&
mantissa
==
0
)
{
return
0
;
}
mantissa
&=
(
1UL
<<
NumericUtils
<
T
>::
mant
)
-
1
;
return
(
sign
<<
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
|
(
out_exponent
<<
NumericUtils
<
T
>::
mant
)
|
mantissa
;
}
template
<
typename
T
>
inline
T
convert_to_type_sr
(
float
value
,
uint32_t
seed
)
{
using
bitwise_type
=
NumericUtils
<
T
>::
bitwise_type
;
if
(
std
::
abs
(
value
)
>
NumericLimits
<
T
>::
Max
())
{
float
max_value
=
NumericLimits
<
T
>::
Max
();
cvt
t
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
max_value
;
uint
max_bitwise
=
t
.
value_bitwise
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
value
;
T
sign
=
t
.
value_bitwise
>>
(
NumericUtils
<
float
>::
exp
+
NumericUtils
<
float
>::
mant
);
T
exp
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
)
-
(
NumericUtils
<
float
>::
bias
-
NumericUtils
<
T
>::
bias
);
uint32_t
mant_prev
=
max_bitwise
>>
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
mant_prev
&=
((
1UL
<<
NumericUtils
<
T
>::
mant
)
-
1
);
mant_prev
--
;
mant_prev
<<=
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
uint32_t
prev_bit
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
<<
NumericUtils
<
float
>::
mant
)
|
mant_prev
;
t
.
value_bitwise
=
prev_bit
;
float
prev_val
=
t
.
value_float
;
float
diff
=
max_value
-
prev_val
;
float
actual_max
=
max_value
+
(
diff
/
2
);
if
(
std
::
abs
(
value
)
<
actual_max
)
{
double
d_max_value
=
static_cast
<
double
>
(
max_value
);
double
d_actual_max
=
static_cast
<
double
>
(
actual_max
);
double
d_value
=
static_cast
<
double
>
(
value
);
double
d_is
=
std
::
abs
(
d_max_value
-
d_actual_max
);
double
d_seed
=
static_cast
<
double
>
(
seed
);
double
d_prob
=
1.0
f
-
(
std
::
abs
(
d_value
-
d_max_value
)
/
d_is
);
// prob to round down
double
thresh
=
UINT_MAX
*
d_prob
;
if
(
!
get_data_has_inf
<
T
>
()
||
d_seed
<=
thresh
)
// return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
return
sign
==
0
?
NumericUtils
<
f4_t
>
data_max_positive_normal_mask
:
NumericUtils
<
f4_t
>
data_max_negative_normal_mask
;
else
{
exp
++
;
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
// inf
|
(
exp
<<
NumericUtils
<
T
>::
mant
);
}
}
else
{
if
(
!
get_data_has_inf
<
T
>
())
return
(
1
<<
(
NumericUtils
<
T
>::
mant
+
NumericUtils
<
T
>::
exp
))
-
1
;
else
{
exp
++
;
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
// inf
|
(
exp
<<
NumericUtils
<
T
>::
mant
);
}
}
}
// uint32_t f32 = reinterpret_cast<uint32_t&>(value);
uint32_t
f32
=
bit_cast
<
uint32_t
>
(
value
);
auto
f32_mant
=
f32
&
NumericUtils
<
float
>::
mant_mask
;
auto
head
=
f32
&
NumericUtils
<
float
>::
head_mask
;
auto
f32_exp
=
(
head
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
;
auto
sign_bit
=
head
>>
(
NumericUtils
<
float
>::
mant
+
NumericUtils
<
float
>::
exp
);
auto
sign
=
sign_bit
<<
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
);
f32_exp
=
(
int32_t
)
f32exp
-
NumericUtils
<
float
>::
bias
;
int32_t
exp
=
f32_exp
;
auto
mant
=
f32_mant
;
bool
subnorm
=
false
;
if
(
value
==
0
)
return
0b0
;
if
(
exp
>=
NumericUtils
<
T
>::
unbiased_exp_min
)
{
mant
=
f32_mant
;
}
// if the exponent bit is 8, then the subnormal is exactly the same as f32
else
if
(
exp
<
NumericUtils
<
T
>::
unbiased_exp_min
&&
NumericUtils
<
T
>::
exp
<
NumericUtils
<
float
>::
exp
)
{
subnorm
=
true
;
auto
diff
=
(
uint32_t
)(
NumericUtils
<
T
>::
unbiased_exp_min
-
exp
);
if
(
diff
>=
32
)
{
mant
=
0
;
f32_mant
=
0
;
}
else
{
f32_mant
|=
(
uint32_t
)
1
<<
NumericUtils
<
float
>::
mant
;
f32_mant
>>=
diff
;
}
exp
=
0
;
mant
=
f32_mant
;
}
uint32_t
sr_shift
=
NumericUtils
<
T
>::
sr_shift
;
// For stochastic-rounding we add the aligned random value to the
// mantissa and then truncate (RTZ).
mant
+=
seed
>>
sr_shift
;
// Increment exponent when mantissa overflows due to rounding
if
(
mant
>=
(
uint32_t
)
1
<<
NumericUtils
<
float
>::
mant
)
++
exp
;
mant
>>=
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
mant
&=
((
1
<<
NumericUtils
<
T
>::
mant
)
-
1
);
auto
biased_exp
=
(
uint32_t
)
exp
;
if
(
!
subnorm
)
biased_exp
=
(
uint32_t
)(
exp
+
NumericUtils
<
T
>::
bias
);
biased_exp
&=
((
1
<<
NumericUtils
<
T
>::
exp
)
-
1
);
auto
val
=
sign
|
biased_exp
<<
NumericUtils
<
T
>::
mant
|
mant
;
return
val
;
}
}
// namespace ck::utils
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment