Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
48d58131
Commit
48d58131
authored
Nov 22, 2024
by
Rostyslav Geyyer
Browse files
Rename E8M0 type
parent
b1ad4b4f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
95 additions
and
98 deletions
+95
-98
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+25
-28
include/ck/utility/e8m0_utils.hpp
include/ck/utility/e8m0_utils.hpp
+6
-6
include/ck/utility/mxf4_utils.hpp
include/ck/utility/mxf4_utils.hpp
+8
-8
include/ck/utility/mxfp_utils.hpp
include/ck/utility/mxfp_utils.hpp
+5
-5
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+44
-44
test/data_type/test_fp4.cpp
test/data_type/test_fp4.cpp
+7
-7
No files found.
include/ck/utility/data_type.hpp
View file @
48d58131
...
@@ -14,15 +14,15 @@ using f4_t = unsigned _BitInt(4);
...
@@ -14,15 +14,15 @@ using f4_t = unsigned _BitInt(4);
using
f8_t
=
_BitInt
(
8
);
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
struct
e8m0_
scale
_t
struct
e8m0_
bexp
_t
{
{
// E8M0 scale is biased
// E8M0 scale is biased
using
type
=
uint8_t
;
using
type
=
uint8_t
;
type
data
;
type
data
;
constexpr
e8m0_
scale
_t
()
:
data
{
type
{}}
{}
constexpr
e8m0_
bexp
_t
()
:
data
{
type
{}}
{}
constexpr
e8m0_
scale
_t
(
type
init
)
:
data
{
init
}
{}
constexpr
e8m0_
bexp
_t
(
type
init
)
:
data
{
init
}
{}
bool
operator
==
(
const
e8m0_
scale
_t
&
other
)
const
{
return
(
data
==
other
.
data
);
}
bool
operator
==
(
const
e8m0_
bexp
_t
&
other
)
const
{
return
(
data
==
other
.
data
);
}
};
};
struct
f4x2_pk_t
struct
f4x2_pk_t
...
@@ -1813,33 +1813,30 @@ struct NumericLimits<f4_t>
...
@@ -1813,33 +1813,30 @@ struct NumericLimits<f4_t>
};
};
template
<
>
template
<
>
struct
NumericLimits
<
e8m0_
scale
_t
>
struct
NumericLimits
<
e8m0_
bexp
_t
>
{
{
static
constexpr
e8m0_scale_t
binary_min
=
0x00
;
// 0b00000000
static
constexpr
e8m0_bexp_t
binary_min
=
0x00
;
// 0b00000000
static
constexpr
e8m0_scale_t
binary_max
=
0xFE
;
// 0b11111110
static
constexpr
e8m0_bexp_t
binary_max
=
0xFE
;
// 0b11111110
static
constexpr
e8m0_scale_t
binary_qnan
=
0xFF
;
// 0b11111111
static
constexpr
e8m0_bexp_t
binary_qnan
=
0xFF
;
// 0b11111111
static
constexpr
e8m0_scale_t
binary_1
=
0x7F
;
// 0b01111111
static
constexpr
e8m0_bexp_t
binary_1
=
0x7F
;
// 0b01111111
static
constexpr
e8m0_scale_t
binary_2
=
0x80
;
// 0b10000000
static
constexpr
e8m0_bexp_t
binary_2
=
0x80
;
// 0b10000000
static
constexpr
e8m0_scale_t
binary_3
=
0x82
;
// 0b10000010
static
constexpr
e8m0_bexp_t
binary_3
=
0x82
;
// 0b10000010
static
constexpr
e8m0_scale_t
binary_135
=
0x87
;
// 0b10000111
static
constexpr
e8m0_bexp_t
binary_135
=
0x87
;
// 0b10000111
static
constexpr
e8m0_scale_t
binary_142
=
0x8E
;
// 0b10001110
static
constexpr
e8m0_bexp_t
binary_142
=
0x8E
;
// 0b10001110
__host__
__device__
static
constexpr
e8m0_scale_t
Min
()
{
return
e8m0_scale_t
(
binary_min
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Min
()
{
return
e8m0_bexp_t
(
binary_min
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Max
()
{
return
e8m0_scale_t
(
binary_max
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Max
()
{
return
e8m0_bexp_t
(
binary_max
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
QuietNaN
()
__host__
__device__
static
constexpr
e8m0_bexp_t
QuietNaN
()
{
return
e8m0_bexp_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_1
()
{
return
e8m0_bexp_t
(
binary_1
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_2
()
{
return
e8m0_bexp_t
(
binary_2
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_3
()
{
return
e8m0_bexp_t
(
binary_3
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_135
()
{
{
return
e8m0_
scale
_t
(
binary_
qnan
);
return
e8m0_
bexp
_t
(
binary_
135
);
}
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_1
()
{
return
e8m0_scale_t
(
binary_1
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_142
()
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_2
()
{
return
e8m0_scale_t
(
binary_2
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_3
()
{
return
e8m0_scale_t
(
binary_3
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_135
()
{
{
return
e8m0_scale_t
(
binary_135
);
return
e8m0_bexp_t
(
binary_142
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_142
()
{
return
e8m0_scale_t
(
binary_142
);
}
}
};
};
...
@@ -1944,7 +1941,7 @@ struct NumericUtils<f4_t>
...
@@ -1944,7 +1941,7 @@ struct NumericUtils<f4_t>
};
};
template
<
>
template
<
>
struct
NumericUtils
<
e8m0_
scale
_t
>
struct
NumericUtils
<
e8m0_
bexp
_t
>
{
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
0
;
static
constexpr
int
mant
=
0
;
...
...
include/ck/utility/e8m0_utils.hpp
View file @
48d58131
...
@@ -8,24 +8,24 @@
...
@@ -8,24 +8,24 @@
namespace
ck
::
utils
{
namespace
ck
::
utils
{
__host__
__device__
inline
float
cast_to_float
(
e8m0_
scale
_t
const
scale
)
__host__
__device__
inline
float
cast_to_float
(
e8m0_
bexp
_t
const
bexp
)
{
{
// TODO: check performance and try bit shift impl
// TODO: check performance and try bit shift impl
return
std
::
powf
(
2
,
bit_cast
<
uint8_t
>
(
scale
)
-
NumericUtils
<
e8m0_
scale
_t
>::
bias
);
return
std
::
powf
(
2
,
bit_cast
<
uint8_t
>
(
bexp
)
-
NumericUtils
<
e8m0_
bexp
_t
>::
bias
);
}
}
__host__
__device__
inline
e8m0_
scale
_t
cast_from_float
(
float
const
scale
)
__host__
__device__
inline
e8m0_
bexp
_t
cast_from_float
(
float
const
scale
)
{
{
uint32_t
e
=
bit_cast
<
uint32_t
>
(
scale
)
&
NumericUtils
<
float
>::
nan_mask
;
uint32_t
e
=
bit_cast
<
uint32_t
>
(
scale
)
&
NumericUtils
<
float
>::
nan_mask
;
return
static_cast
<
uint8_t
>
(
e
>>
23
);
return
static_cast
<
uint8_t
>
(
e
>>
23
);
}
}
template
<
>
template
<
>
__host__
__device__
inline
int
get_exponent_value
<
e8m0_
scale
_t
>
(
e8m0_
scale
_t
x
)
__host__
__device__
inline
int
get_exponent_value
<
e8m0_
bexp
_t
>
(
e8m0_
bexp
_t
x
)
{
{
x
.
data
>>=
NumericUtils
<
e8m0_
scale
_t
>::
mant
;
x
.
data
>>=
NumericUtils
<
e8m0_
bexp
_t
>::
mant
;
x
.
data
&=
((
1
<<
NumericUtils
<
e8m0_
scale
_t
>::
exp
)
-
1
);
x
.
data
&=
((
1
<<
NumericUtils
<
e8m0_
bexp
_t
>::
exp
)
-
1
);
return
static_cast
<
int
>
(
x
.
data
);
return
static_cast
<
int
>
(
x
.
data
);
}
}
...
...
include/ck/utility/mxf4_utils.hpp
View file @
48d58131
...
@@ -9,16 +9,16 @@
...
@@ -9,16 +9,16 @@
namespace
ck
::
utils
{
namespace
ck
::
utils
{
template
<
>
template
<
>
__host__
__device__
inline
bool
is_nan
<
f4_t
>
(
e8m0_
scale
_t
const
scale
,
__host__
__device__
inline
bool
is_nan
<
f4_t
>
(
e8m0_
bexp
_t
const
scale
,
f4_t
const
dataBytes
[[
maybe_unused
]])
f4_t
const
dataBytes
[[
maybe_unused
]])
{
{
// no need to check for data as it does not have NaN representation
// no need to check for data as it does not have NaN representation
return
scale
==
NumericLimits
<
e8m0_
scale
_t
>::
QuietNaN
();
return
scale
==
NumericLimits
<
e8m0_
bexp
_t
>::
QuietNaN
();
}
}
// no infinity representation in ocp_e2m1_mxfp4 will always return false
// no infinity representation in ocp_e2m1_mxfp4 will always return false
template
<
>
template
<
>
__host__
__device__
inline
bool
is_inf
<
f4_t
>
(
e8m0_
scale
_t
const
scale
[[
maybe_unused
]],
__host__
__device__
inline
bool
is_inf
<
f4_t
>
(
e8m0_
bexp
_t
const
scale
[[
maybe_unused
]],
f4_t
const
data
[[
maybe_unused
]])
f4_t
const
data
[[
maybe_unused
]])
{
{
// no inf representation for ocp_e2m1_mxfp4
// no inf representation for ocp_e2m1_mxfp4
...
@@ -26,7 +26,7 @@ __host__ __device__ inline bool is_inf<f4_t>(e8m0_scale_t const scale [[maybe_un
...
@@ -26,7 +26,7 @@ __host__ __device__ inline bool is_inf<f4_t>(e8m0_scale_t const scale [[maybe_un
}
}
template
<
>
template
<
>
__host__
__device__
inline
bool
is_zero
<
f4_t
>
(
e8m0_
scale
_t
const
scale
,
f4_t
const
data
)
__host__
__device__
inline
bool
is_zero
<
f4_t
>
(
e8m0_
bexp
_t
const
scale
,
f4_t
const
data
)
{
{
if
(
is_nan
<
f4_t
>
(
scale
,
data
))
if
(
is_nan
<
f4_t
>
(
scale
,
data
))
return
false
;
return
false
;
...
@@ -38,7 +38,7 @@ __host__ __device__ inline bool is_zero<f4_t>(e8m0_scale_t const scale, f4_t con
...
@@ -38,7 +38,7 @@ __host__ __device__ inline bool is_zero<f4_t>(e8m0_scale_t const scale, f4_t con
}
}
template
<
>
template
<
>
__host__
__device__
inline
float
to_float
<
f4_t
>
(
e8m0_
scale
_t
const
scale
,
f4_t
const
data
)
__host__
__device__
inline
float
to_float
<
f4_t
>
(
e8m0_
bexp
_t
const
scale
,
f4_t
const
data
)
{
{
if
(
is_nan
<
f4_t
>
(
scale
,
data
))
if
(
is_nan
<
f4_t
>
(
scale
,
data
))
return
std
::
numeric_limits
<
float
>::
quiet_NaN
();
return
std
::
numeric_limits
<
float
>::
quiet_NaN
();
...
@@ -48,7 +48,7 @@ __host__ __device__ inline float to_float<f4_t>(e8m0_scale_t const scale, f4_t c
...
@@ -48,7 +48,7 @@ __host__ __device__ inline float to_float<f4_t>(e8m0_scale_t const scale, f4_t c
f4_t
prepared_data
=
data
&
0b00001111
;
f4_t
prepared_data
=
data
&
0b00001111
;
int
scale_exp
=
get_exponent_value
<
e8m0_
scale
_t
>
(
scale
);
int
scale_exp
=
get_exponent_value
<
e8m0_
bexp
_t
>
(
scale
);
return
convert_to_float
<
f4_t
>
(
prepared_data
,
scale_exp
);
return
convert_to_float
<
f4_t
>
(
prepared_data
,
scale_exp
);
}
}
...
@@ -73,7 +73,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
...
@@ -73,7 +73,7 @@ __host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
f4_t
res
=
convert_to_type
<
f4_t
>
(
value
);
f4_t
res
=
convert_to_type
<
f4_t
>
(
value
);
if
(
std
::
abs
(
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
res
))
<
if
(
std
::
abs
(
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
res
))
<
NumericLimits
<
f4_t
>::
DataMinSubnorm
())
NumericLimits
<
f4_t
>::
DataMinSubnorm
())
return
value
<
0
?
NumericUtils
<
f4_t
>::
negative_zero_mask
return
value
<
0
?
NumericUtils
<
f4_t
>::
negative_zero_mask
:
NumericUtils
<
f4_t
>::
positive_zero_mask
;
:
NumericUtils
<
f4_t
>::
positive_zero_mask
;
...
@@ -98,7 +98,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
...
@@ -98,7 +98,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
f4_t
res
=
convert_to_type_sr
<
f4_t
>
(
value
,
seed
);
f4_t
res
=
convert_to_type_sr
<
f4_t
>
(
value
,
seed
);
if
(
std
::
abs
(
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
res
))
<
if
(
std
::
abs
(
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
res
))
<
NumericLimits
<
f4_t
>::
DataMinSubnorm
())
NumericLimits
<
f4_t
>::
DataMinSubnorm
())
return
value
<
0
?
NumericUtils
<
f4_t
>::
negative_zero_mask
return
value
<
0
?
NumericUtils
<
f4_t
>::
negative_zero_mask
:
NumericUtils
<
f4_t
>::
positive_zero_mask
;
:
NumericUtils
<
f4_t
>::
positive_zero_mask
;
...
...
include/ck/utility/mxfp_utils.hpp
View file @
48d58131
...
@@ -18,13 +18,13 @@ inline bool getDataHasInf()
...
@@ -18,13 +18,13 @@ inline bool getDataHasInf()
}
}
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
inline
bool
is_zero
(
e8m0_
scale
_t
const
scale
,
T
const
data
);
__host__
__device__
inline
bool
is_zero
(
e8m0_
bexp
_t
const
scale
,
T
const
data
);
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
inline
bool
is_nan
(
e8m0_
scale
_t
const
scale
,
T
const
data
);
__host__
__device__
inline
bool
is_nan
(
e8m0_
bexp
_t
const
scale
,
T
const
data
);
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
inline
bool
is_inf
(
e8m0_
scale
_t
const
scale
,
T
const
data
);
__host__
__device__
inline
bool
is_inf
(
e8m0_
bexp
_t
const
scale
,
T
const
data
);
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
inline
int
get_exponent_value
(
T
x
)
__host__
__device__
inline
int
get_exponent_value
(
T
x
)
...
@@ -79,13 +79,13 @@ __host__ __device__ float convert_to_float(T data, int scale_exp)
...
@@ -79,13 +79,13 @@ __host__ __device__ float convert_to_float(T data, int scale_exp)
float
data_value
=
d_sign
*
d_exp
*
d_mant
;
float
data_value
=
d_sign
*
d_exp
*
d_mant
;
float
scale_value
=
std
::
pow
(
float
scale_value
=
std
::
pow
(
2
,
static_cast
<
float
>
((
scale_exp
-
static_cast
<
int
>
(
NumericUtils
<
e8m0_
scale
_t
>::
bias
))));
2
,
static_cast
<
float
>
((
scale_exp
-
static_cast
<
int
>
(
NumericUtils
<
e8m0_
bexp
_t
>::
bias
))));
return
data_value
*
scale_value
;
return
data_value
*
scale_value
;
}
}
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
inline
float
to_float
(
e8m0_
scale
_t
const
scale
,
T
const
data
);
__host__
__device__
inline
float
to_float
(
e8m0_
bexp
_t
const
scale
,
T
const
data
);
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
T
sat_convert_to_type
(
float
value
);
__host__
__device__
T
sat_convert_to_type
(
float
value
);
...
...
include/ck/utility/type_convert.hpp
View file @
48d58131
...
@@ -1000,7 +1000,7 @@ inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
...
@@ -1000,7 +1000,7 @@ inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
float_values
.
float2_array
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
scale
,
0
);
float_values
.
float2_array
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
scale
,
0
);
return
float_values
.
float_array
[
0
];
return
float_values
.
float_array
[
0
];
#else
#else
return
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
x
);
return
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
x
);
#endif
#endif
}
}
...
@@ -1018,8 +1018,8 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
...
@@ -1018,8 +1018,8 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
float
scale
=
1.0
f
;
float
scale
=
1.0
f
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
scale
,
0
);
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
scale
,
0
);
#else
#else
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
x
.
unpack
(
1
)),
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
x
.
unpack
(
1
)),
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
x
.
unpack
(
0
))};
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
x
.
unpack
(
0
))};
return
ret
;
return
ret
;
#endif
#endif
}
}
...
@@ -1153,72 +1153,72 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
...
@@ -1153,72 +1153,72 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
f4x32_t
f4x32_array
;
f4x32_t
f4x32_array
;
}
f4_values
{
bit_cast
<
__uint128_t
>
(
x
)};
}
f4_values
{
bit_cast
<
__uint128_t
>
(
x
)};
// TODO: pack in a loop
// TODO: pack in a loop
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
0
].
unpack
(
0
));
f4_values
.
f4x2_array
[
0
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
0
].
unpack
(
1
));
f4_values
.
f4x2_array
[
0
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
1
].
unpack
(
0
));
f4_values
.
f4x2_array
[
1
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
1
].
unpack
(
1
));
f4_values
.
f4x2_array
[
1
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
2
].
unpack
(
0
));
f4_values
.
f4x2_array
[
2
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
2
].
unpack
(
1
));
f4_values
.
f4x2_array
[
2
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
3
].
unpack
(
0
));
f4_values
.
f4x2_array
[
3
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
3
].
unpack
(
1
));
f4_values
.
f4x2_array
[
3
].
unpack
(
1
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
4
].
unpack
(
0
));
f4_values
.
f4x2_array
[
4
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
4
].
unpack
(
1
));
f4_values
.
f4x2_array
[
4
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
5
].
unpack
(
0
));
f4_values
.
f4x2_array
[
5
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
5
].
unpack
(
1
));
f4_values
.
f4x2_array
[
5
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
6
].
unpack
(
0
));
f4_values
.
f4x2_array
[
6
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
6
].
unpack
(
1
));
f4_values
.
f4x2_array
[
6
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
7
].
unpack
(
0
));
f4_values
.
f4x2_array
[
7
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
7
].
unpack
(
1
));
f4_values
.
f4x2_array
[
7
].
unpack
(
1
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
8
].
unpack
(
0
));
f4_values
.
f4x2_array
[
8
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
8
].
unpack
(
1
));
f4_values
.
f4x2_array
[
8
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
9
].
unpack
(
0
));
f4_values
.
f4x2_array
[
9
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
9
].
unpack
(
1
));
f4_values
.
f4x2_array
[
9
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
10
].
unpack
(
0
));
f4_values
.
f4x2_array
[
10
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
10
].
unpack
(
1
));
f4_values
.
f4x2_array
[
10
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
11
].
unpack
(
0
));
f4_values
.
f4x2_array
[
11
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
11
].
unpack
(
1
));
f4_values
.
f4x2_array
[
11
].
unpack
(
1
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
12
].
unpack
(
0
));
f4_values
.
f4x2_array
[
12
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
12
].
unpack
(
1
));
f4_values
.
f4x2_array
[
12
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
13
].
unpack
(
0
));
f4_values
.
f4x2_array
[
13
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
13
].
unpack
(
1
));
f4_values
.
f4x2_array
[
13
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
14
].
unpack
(
0
));
f4_values
.
f4x2_array
[
14
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
14
].
unpack
(
1
));
f4_values
.
f4x2_array
[
14
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
15
].
unpack
(
0
));
f4_values
.
f4x2_array
[
15
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
scale
_t
>::
Binary_1
(),
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_
bexp
_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
15
].
unpack
(
1
));
f4_values
.
f4x2_array
[
15
].
unpack
(
1
));
return
float_values
.
float32_array
;
return
float_values
.
float32_array
;
...
@@ -1226,24 +1226,24 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
...
@@ -1226,24 +1226,24 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
}
}
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
e8m0_
scale
_t
>
(
e8m0_
scale
_t
scale
)
inline
__host__
__device__
float
type_convert
<
float
,
e8m0_
bexp
_t
>
(
e8m0_
bexp
_t
scale
)
{
{
return
utils
::
cast_to_float
(
scale
);
return
utils
::
cast_to_float
(
scale
);
}
}
template
<
>
template
<
>
inline
__host__
__device__
e8m0_
scale
_t
type_convert
<
e8m0_
scale
_t
,
float
>
(
float
scale
)
inline
__host__
__device__
e8m0_
bexp
_t
type_convert
<
e8m0_
bexp
_t
,
float
>
(
float
scale
)
{
{
return
utils
::
cast_from_float
(
scale
);
return
utils
::
cast_from_float
(
scale
);
}
}
// Declare a template function for scaled conversion
// Declare a template function for scaled conversion
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
scaled_type_convert
(
e8m0_
scale
_t
scale
,
X
x
);
__host__
__device__
constexpr
Y
scaled_type_convert
(
e8m0_
bexp
_t
scale
,
X
x
);
// convert fp4 to fp32
// convert fp4 to fp32
template
<
>
template
<
>
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_
scale
_t
scale
,
f4_t
x
)
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_
bexp
_t
scale
,
f4_t
x
)
{
{
#if defined(__gfx950__)
#if defined(__gfx950__)
union
union
...
@@ -1261,7 +1261,7 @@ inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t s
...
@@ -1261,7 +1261,7 @@ inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t s
// convert vector of 2 fp4 to vector of 2 fp32
// convert vector of 2 fp4 to vector of 2 fp32
template
<
>
template
<
>
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_
scale
_t
scale
,
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_
bexp
_t
scale
,
f4x2_t
x
)
f4x2_t
x
)
{
{
#if defined(__gfx950__)
#if defined(__gfx950__)
...
@@ -1281,7 +1281,7 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_s
...
@@ -1281,7 +1281,7 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_s
// convert vector of 32 fp4 to vector of 32 fp32
// convert vector of 32 fp4 to vector of 32 fp32
template
<
>
template
<
>
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f4x32_t
>
(
e8m0_
scale
_t
scale
,
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f4x32_t
>
(
e8m0_
bexp
_t
scale
,
f4x32_t
x
)
f4x32_t
x
)
{
{
#if defined(__gfx950__)
#if defined(__gfx950__)
...
@@ -1450,7 +1450,7 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
...
@@ -1450,7 +1450,7 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
// convert fp32 to fp4
// convert fp32 to fp4
template
<
>
template
<
>
inline
__host__
__device__
f4_t
scaled_type_convert
<
f4_t
,
float
>
(
e8m0_
scale
_t
scale
,
float
x
)
inline
__host__
__device__
f4_t
scaled_type_convert
<
f4_t
,
float
>
(
e8m0_
bexp
_t
scale
,
float
x
)
{
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -1461,7 +1461,7 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t sc
...
@@ -1461,7 +1461,7 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t sc
// convert vector of 2 fp32 to vector of 2 fp4
// convert vector of 2 fp32 to vector of 2 fp4
template
<
>
template
<
>
inline
__host__
__device__
f4x2_t
scaled_type_convert
<
f4x2_t
,
float2_t
>
(
e8m0_
scale
_t
scale
,
inline
__host__
__device__
f4x2_t
scaled_type_convert
<
f4x2_t
,
float2_t
>
(
e8m0_
bexp
_t
scale
,
float2_t
x
)
float2_t
x
)
{
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F4_CONVERSION
...
@@ -1473,7 +1473,7 @@ inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_sca
...
@@ -1473,7 +1473,7 @@ inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_sca
// convert vector of 32 fp32 to vector of 32 fp4
// convert vector of 32 fp32 to vector of 32 fp4
template
<
>
template
<
>
inline
__host__
__device__
f4x32_t
scaled_type_convert
<
f4x32_t
,
float32_t
>
(
e8m0_
scale
_t
scale
,
inline
__host__
__device__
f4x32_t
scaled_type_convert
<
f4x32_t
,
float32_t
>
(
e8m0_
bexp
_t
scale
,
float32_t
x
)
float32_t
x
)
{
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F4_CONVERSION
...
...
test/data_type/test_fp4.cpp
View file @
48d58131
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
e8m0_
scale
_t
;
using
ck
::
e8m0_
bexp
_t
;
using
ck
::
f4_convert_rne
;
using
ck
::
f4_convert_rne
;
using
ck
::
f4_convert_sr
;
using
ck
::
f4_convert_sr
;
using
ck
::
f4_t
;
using
ck
::
f4_t
;
...
@@ -90,10 +90,10 @@ TEST(FP4, ScaledConvertFP32Nearest)
...
@@ -90,10 +90,10 @@ TEST(FP4, ScaledConvertFP32Nearest)
float
max_fp4
=
6.0
f
;
float
max_fp4
=
6.0
f
;
// set maximum scale
// set maximum scale
float
max_scale
=
std
::
pow
(
2
,
float
max_scale
=
std
::
pow
(
2
,
ck
::
NumericLimits
<
e8m0_
scale
_t
>::
Max
().
data
-
ck
::
NumericLimits
<
e8m0_
bexp
_t
>::
Max
().
data
-
ck
::
NumericUtils
<
e8m0_
scale
_t
>::
bias
);
// 0xFE -> float
ck
::
NumericUtils
<
e8m0_
bexp
_t
>::
bias
);
// 0xFE -> float
// set minimum scale
// set minimum scale
float
min_scale
=
std
::
pow
(
2
,
-
ck
::
NumericUtils
<
e8m0_
scale
_t
>::
bias
);
// 0x00 -> float
float
min_scale
=
std
::
pow
(
2
,
-
ck
::
NumericUtils
<
e8m0_
bexp
_t
>::
bias
);
// 0x00 -> float
// set arbitrary scale to 256.0
// set arbitrary scale to 256.0
float
test_scale
=
256.0
f
;
// 0b10000111
float
test_scale
=
256.0
f
;
// 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
// convert 0 float to fp4 and back with maximal scale, check if holds
...
@@ -162,10 +162,10 @@ TEST(FP4, ScaledConvertFP32Stochastic)
...
@@ -162,10 +162,10 @@ TEST(FP4, ScaledConvertFP32Stochastic)
float
max_fp4
=
6.0
f
;
float
max_fp4
=
6.0
f
;
// set maximum scale
// set maximum scale
float
max_scale
=
std
::
pow
(
2
,
float
max_scale
=
std
::
pow
(
2
,
ck
::
NumericLimits
<
e8m0_
scale
_t
>::
Max
().
data
-
ck
::
NumericLimits
<
e8m0_
bexp
_t
>::
Max
().
data
-
ck
::
NumericUtils
<
e8m0_
scale
_t
>::
bias
);
// 0xFE -> float
ck
::
NumericUtils
<
e8m0_
bexp
_t
>::
bias
);
// 0xFE -> float
// set minimum scale
// set minimum scale
float
min_scale
=
std
::
pow
(
2
,
-
ck
::
NumericUtils
<
e8m0_
scale
_t
>::
bias
);
// 0x00 -> float
float
min_scale
=
std
::
pow
(
2
,
-
ck
::
NumericUtils
<
e8m0_
bexp
_t
>::
bias
);
// 0x00 -> float
// set arbitrary scale to 256.0
// set arbitrary scale to 256.0
float
test_scale
=
256.0
f
;
// 0b10000111
float
test_scale
=
256.0
f
;
// 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
// convert 0 float to fp4 and back with maximal scale, check if holds
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment