Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b2b68eff
Commit
b2b68eff
authored
Oct 19, 2023
by
Rostyslav Geyyer
Browse files
Fix the conversion
parent
bf0addb5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
111 additions
and
100 deletions
+111
-100
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+7
-21
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+98
-53
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+6
-26
No files found.
include/ck/utility/data_type.hpp
View file @
b2b68eff
...
@@ -12,12 +12,8 @@ using half_t = _Float16;
...
@@ -12,12 +12,8 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
#endif
#endif
#if defined CK_ENABLE_FP8
using
f8_t
=
_BitInt
(
8
);
using
f8_t
=
_BitInt
(
8
);
#endif
#if defined CK_ENABLE_BF8
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
#endif
// vector_type
// vector_type
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
@@ -148,23 +144,19 @@ struct scalar_type<int4_t>
...
@@ -148,23 +144,19 @@ struct scalar_type<int4_t>
};
};
#endif
#endif
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
scalar_type
<
f8_t
>
struct
scalar_type
<
f8_t
>
{
{
using
type
=
f8_t
;
using
type
=
f8_t
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
scalar_type
<
bf8_t
>
struct
scalar_type
<
bf8_t
>
{
{
using
type
=
bf8_t
;
using
type
=
bf8_t
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
struct
vector_type
<
T
,
1
>
...
@@ -968,24 +960,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
...
@@ -968,24 +960,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
// f8
#if defined CK_ENABLE_FP8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
#endif
// bf8
// bf8
#if defined CK_ENABLE_BF8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
...
@@ -1033,7 +1021,6 @@ struct NumericLimits<int4_t>
...
@@ -1033,7 +1021,6 @@ struct NumericLimits<int4_t>
};
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
NumericLimits
<
f8_t
>
struct
NumericLimits
<
f8_t
>
{
{
...
@@ -1056,9 +1043,7 @@ struct NumericLimits<f8_t>
...
@@ -1056,9 +1043,7 @@ struct NumericLimits<f8_t>
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
NumericLimits
<
bf8_t
>
struct
NumericLimits
<
bf8_t
>
{
{
...
@@ -1081,7 +1066,6 @@ struct NumericLimits<bf8_t>
...
@@ -1081,7 +1066,6 @@ struct NumericLimits<bf8_t>
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
};
};
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
NumericUtils
struct
NumericUtils
...
@@ -1093,6 +1077,7 @@ struct NumericUtils<float>
...
@@ -1093,6 +1077,7 @@ struct NumericUtils<float>
{
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
23
;
static
constexpr
int
mant
=
23
;
static
constexpr
int
bias
=
127
;
static
constexpr
uint32_t
nan_mask
=
0x7F800000
;
static
constexpr
uint32_t
nan_mask
=
0x7F800000
;
static
constexpr
uint32_t
head_mask
=
0xFF800000
;
static
constexpr
uint32_t
head_mask
=
0xFF800000
;
static
constexpr
uint32_t
mant_mask
=
0x7FFFFF
;
static
constexpr
uint32_t
mant_mask
=
0x7FFFFF
;
...
@@ -1109,6 +1094,7 @@ struct NumericUtils<half_t>
...
@@ -1109,6 +1094,7 @@ struct NumericUtils<half_t>
{
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
10
;
static
constexpr
int
mant
=
10
;
static
constexpr
int
bias
=
15
;
static
constexpr
uint16_t
nan_mask
=
0x7C00
;
static
constexpr
uint16_t
nan_mask
=
0x7C00
;
static
constexpr
uint16_t
head_mask
=
0xFC00
;
static
constexpr
uint16_t
head_mask
=
0xFC00
;
static
constexpr
uint16_t
mant_mask
=
0x3FF
;
static
constexpr
uint16_t
mant_mask
=
0x3FF
;
...
@@ -1120,22 +1106,22 @@ struct NumericUtils<half_t>
...
@@ -1120,22 +1106,22 @@ struct NumericUtils<half_t>
using
bitwise_type
=
uint16_t
;
using
bitwise_type
=
uint16_t
;
};
};
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
NumericUtils
<
f8_t
>
struct
NumericUtils
<
f8_t
>
{
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
8
;
// negative zero nan mode
// static constexpr int bias = 7; // ieee mode
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
NumericUtils
<
bf8_t
>
struct
NumericUtils
<
bf8_t
>
{
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
16
;
// negative zero nan mode
// static constexpr int bias = 15; // ieee mode
};
};
#endif
}
// namespace ck
}
// namespace ck
include/ck/utility/f8_utils.hpp
View file @
b2b68eff
...
@@ -5,9 +5,6 @@
...
@@ -5,9 +5,6 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace
ck
{
namespace
ck
{
// fp8 rounding modes
// fp8 rounding modes
...
@@ -19,6 +16,9 @@ enum class f8_rounding_mode
...
@@ -19,6 +16,9 @@ enum class f8_rounding_mode
stochastic
stochastic
};
};
__host__
inline
int
clz
(
uint32_t
x
)
{
return
__builtin_clz
(
x
);
}
__device__
inline
int
clz
(
uint32_t
x
)
{
return
__clz
(
x
);
}
}
// namespace ck
}
// namespace ck
namespace
ck
::
utils
{
namespace
ck
::
utils
{
...
@@ -36,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -36,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
int
exponent
;
int
exponent
,
bias
;
uint32_t
head
,
mantissa
,
sign
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
// nan code is same for float and half
constexpr
Y
nan_code
=
0x80
;
constexpr
Y
nan_code
=
0x80
;
...
@@ -51,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -51,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
mantissa
=
x_bitwise
&
NumericUtils
<
X
>::
mant_mask
;
mantissa
=
x_bitwise
&
NumericUtils
<
X
>::
mant_mask
;
exponent
=
(
head
>>
in_mant
)
&
NumericUtils
<
X
>::
exp_mask
;
exponent
=
(
head
>>
in_mant
)
&
NumericUtils
<
X
>::
exp_mask
;
sign
=
head
>>
(
in_exp
+
in_mant
);
sign
=
head
>>
(
in_exp
+
in_mant
);
bias
=
NumericUtils
<
X
>::
bias
;
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
in_exp
-
1
))
-
(
1
<<
(
out_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
if
constexpr
(
negative_zero_nan
)
if
constexpr
(
negative_zero_nan
)
{
{
...
@@ -69,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -69,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
// if input is half and output is bf8
if
((
NumericUtils
<
X
>::
mant
==
10
)
&&
(
NumericUtils
<
Y
>::
mant
==
2
)
&&
negative_zero_nan
&&
exponent
==
0
)
{
exponent
+=
1
;
while
(
mantissa
<
(
1
<<
in_mant
))
{
mantissa
<<=
1
;
exponent
-=
1
;
}
mantissa
&=
~
(
1
<<
in_mant
);
}
// check if x is 0.0
// check if x is 0.0
if
(
x_bitwise
==
0
)
if
(
x_bitwise
==
0
)
return
0
;
return
0
;
exponent
-=
exp_low_cutoff
-
1
;
// First need to check if it is normal or denorm as there is a difference of implict 1
if
(
exponent
<=
0
)
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
+
1
-
exponent
))
-
1
;
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
mantissa
+=
1
<<
in_mant
;
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
// apply random number if needed
// exponent and mantissa again3
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
if
(
mantissa
>=
(
2
<<
in_mant
))
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
{
const
int
out_bias
=
(
1
<<
(
out_exp
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
mantissa
>>=
1
;
const
int
out_denormal_act_exponent
=
1
-
out_bias
;
// actual exponent of f8 denormal
exponent
++
;
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
out_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
out_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
out_denormal_act_exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
out_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa
+=
(
1
<<
in_mant
);
// Add the implicit 1 into mantissa
}
}
mantissa
>>=
(
in_mant
-
out_mant
);
// check negative exponent
bool
midpoint
=
(
mantissa
&
((
1
<<
(
in_mant
-
out_mant
+
exponent_diff
))
-
1
))
==
if
(
exponent
<=
0
)
(
1
<<
(
in_mant
-
out_mant
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
in_mant
);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
out_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
out_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
bool
odd
=
mantissa
&
(
1
<<
(
in_mant
-
out_mant
));
// if the least significant bit that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
out_exponent
==
0
)
{
{
if
(
x_bitwise
==
0
)
if
((
1
<<
in_mant
)
&
mantissa
)
return
0
;
else
{
{
// subnormal range; represented by a subnormal float8 (exponent 0)
out_exponent
=
1
;
// denormal overflow to become normal, promote exponent
// and involves loss of accuracy
// No need to make 1 implicit now as it will be addressed later
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
}
}
}
// above range: quantize to maximum possible float of the same sign
else
else
if
(
exponent
>
max_exp
)
{
if
((
1
<<
(
in_mant
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
out_exponent
++
;
// No need to make 1 implicit now as it will be addressed later
}
}
mantissa
>>=
(
in_mant
-
out_mant
);
if
(
out_exponent
>
max_exp
)
{
{
if
(
clip
)
if
(
clip
)
{
{
mantissa
=
(
1
<<
out_mant
)
-
1
;
mantissa
=
(
1
<<
out_mant
)
-
1
;
exponent
=
max_exp
;
out_
exponent
=
max_exp
;
}
}
else
else
{
{
...
@@ -127,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -127,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
}
}
// check if x is 0.0 or -0.0
// check if x is 0.0 or -0.0
if
(
exponent
==
0
&&
mantissa
==
0
)
if
(
out_
exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
(
out_exp
+
out_mant
));
return
negative_zero_nan
?
0
:
(
sign
<<
(
out_exp
+
out_mant
));
mantissa
&=
(
1
<<
out_mant
)
-
1
;
mantissa
&=
(
1
<<
out_mant
)
-
1
;
return
(
sign
<<
(
out_exp
+
out_mant
))
|
(
exponent
<<
out_mant
)
|
mantissa
;
return
(
sign
<<
(
out_exp
+
out_mant
))
|
(
out_
exponent
<<
out_mant
)
|
mantissa
;
}
}
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
...
@@ -196,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
...
@@ -196,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
if
(
exponent
==
0
)
if
(
exponent
==
0
)
{
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
exponent
++
;
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
in_mant
);
while
(
mantissa
<
(
1
<<
in_mant
))
mantissa
<<=
sh
;
{
exponent
+=
1
-
sh
;
mantissa
<<=
1
;
exponent
--
;
}
mantissa
&=
((
1
<<
in_mant
)
-
1
);
mantissa
&=
((
1
<<
in_mant
)
-
1
);
}
}
exponent
+=
exp_low_cutoff
-
1
;
exponent
+=
exp_low_cutoff
-
1
;
...
@@ -244,5 +291,3 @@ __host__ __device__ Y cast_from_f8(X x)
...
@@ -244,5 +291,3 @@ __host__ __device__ Y cast_from_f8(X x)
}
}
}
// namespace ck::utils
}
// namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
include/ck/utility/type_convert.hpp
View file @
b2b68eff
...
@@ -95,7 +95,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
...
@@ -95,7 +95,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
}
#if defined CK_ENABLE_FP8
// convert fp32 to fp8
// convert fp32 to fp8
template
<
>
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
...
@@ -146,7 +145,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
@@ -146,7 +145,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
// convert to float and use native converion
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
@@ -154,8 +153,6 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
@@ -154,8 +153,6 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
...
@@ -166,16 +163,12 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
...
@@ -166,16 +163,12 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
#else
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8
// convert fp32 to bf8
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
...
@@ -226,7 +219,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
...
@@ -226,7 +219,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
// convert to float and use native converion
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
@@ -234,8 +227,6 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
...
@@ -234,8 +227,6 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
...
@@ -246,14 +237,11 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
...
@@ -246,14 +237,11 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
#else
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
#endif
// Declare a template function for bf16 conversion using RTN
// Declare a template function for bf16 conversion using RTN
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
...
@@ -316,7 +304,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
...
@@ -316,7 +304,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding
// convert fp32 to fp8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
...
@@ -352,7 +339,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -352,7 +339,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
...
@@ -361,13 +348,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -361,13 +348,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 with stochastic rounding
// convert fp32 to bf8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
...
@@ -403,7 +386,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
...
@@ -403,7 +386,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
#el
se
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
...
@@ -413,10 +396,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
...
@@ -413,10 +396,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
f8_convert_sr
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
#endif
}
// namespace ck
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment