Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
cde0f480
Unverified
Commit
cde0f480
authored
Dec 03, 2024
by
Rostyslav Geyyer
Committed by
GitHub
Dec 03, 2024
Browse files
Merge pull request #200 from ROCm/lwpck-2390
Enable MXFP4 type
parents
b7566434
773c0e70
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
2328 additions
and
29 deletions
+2328
-29
include/ck/ck.hpp
include/ck/ck.hpp
+3
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+149
-5
include/ck/utility/e8m0_utils.hpp
include/ck/utility/e8m0_utils.hpp
+33
-0
include/ck/utility/mxf4_utils.hpp
include/ck/utility/mxf4_utils.hpp
+109
-0
include/ck/utility/mxfp_utils.hpp
include/ck/utility/mxfp_utils.hpp
+384
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+1049
-0
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+75
-24
library/include/ck/library/utility/host_tensor_generator.hpp
library/include/ck/library/utility/host_tensor_generator.hpp
+43
-0
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+4
-0
test/data_type/test_fp4.cpp
test/data_type/test_fp4.cpp
+479
-0
No files found.
include/ck/ck.hpp
View file @
cde0f480
...
@@ -158,6 +158,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -158,6 +158,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// set rounding to nearest even as default for f8 conversions
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
#define CK_USE_SR_F8_CONVERSION 0
// set rounding to nearest even as default for f4 conversions
#define CK_USE_SR_F4_CONVERSION 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
...
include/ck/utility/data_type.hpp
View file @
cde0f480
...
@@ -11,6 +11,40 @@ namespace ck {
...
@@ -11,6 +11,40 @@ namespace ck {
using
bhalf_t
=
ushort
;
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
using
f4_t
=
unsigned
_BitInt
(
4
);
struct
e8m0_bexp_t
{
// E8M0 scale is biased
using
type
=
uint8_t
;
type
data
;
constexpr
e8m0_bexp_t
()
:
data
{
type
{}}
{}
constexpr
e8m0_bexp_t
(
type
init
)
:
data
{
init
}
{}
bool
operator
==
(
const
e8m0_bexp_t
&
other
)
const
{
return
(
data
==
other
.
data
);
}
};
struct
f4x2_pk_t
{
using
type
=
uint8_t
;
type
data
;
f4x2_pk_t
()
:
data
{
type
{}}
{}
f4x2_pk_t
(
type
init
)
:
data
{
init
}
{}
template
<
index_t
I
>
__host__
__device__
inline
type
unpack
()
const
{
if
constexpr
(
I
==
0
)
return
data
&
0b00001111
;
else
return
(
data
>>
4
);
}
__host__
__device__
inline
type
pack
(
const
type
x0
,
const
type
x1
)
{
return
(
x1
<<
4
)
|
(
x0
&
0b00001111
);
}
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
{
...
@@ -26,7 +60,7 @@ inline constexpr bool is_native_type()
...
@@ -26,7 +60,7 @@ inline constexpr bool is_native_type()
return
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
half_t
>::
value
||
return
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
bhalf_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
bhalf_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
uint8_t
>::
value
||
is_same
<
T
,
f8_fnuz_t
>::
value
||
is_same
<
T
,
uint8_t
>::
value
||
is_same
<
T
,
f8_fnuz_t
>::
value
||
is_same
<
T
,
bf8_fnuz_t
>::
value
||
is_same
<
T
,
bool
>::
value
;
is_same
<
T
,
bf8_fnuz_t
>::
value
||
is_same
<
T
,
bool
>::
value
||
is_same
<
T
,
f4_t
>::
value
;
}
}
// vector_type
// vector_type
...
@@ -1871,6 +1905,14 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
...
@@ -1871,6 +1905,14 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
// f4
using
f4x2_t
=
typename
vector_type
<
f4x2_pk_t
,
1
>::
type
;
using
f4x4_t
=
typename
vector_type
<
f4x2_pk_t
,
2
>::
type
;
using
f4x8_t
=
typename
vector_type
<
f4x2_pk_t
,
4
>::
type
;
using
f4x16_t
=
typename
vector_type
<
f4x2_pk_t
,
8
>::
type
;
using
f4x32_t
=
typename
vector_type
<
f4x2_pk_t
,
16
>::
type
;
using
f4x64_t
=
typename
vector_type
<
f4x2_pk_t
,
32
>::
type
;
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
{
{
...
@@ -2009,6 +2051,59 @@ struct NumericLimits<bf8_ocp_t>
...
@@ -2009,6 +2051,59 @@ struct NumericLimits<bf8_ocp_t>
}
}
};
};
template
<
>
struct
NumericLimits
<
f4_t
>
{
static
constexpr
uint8_t
binary_min_normal
=
0x2
;
// 0b0010
static
constexpr
uint8_t
binary_max_normal
=
0x7
;
// 0b0111
static
constexpr
uint8_t
binary_lowest_normal
=
0xF
;
// 0b1111
static
constexpr
uint8_t
binary_min_subnorm
=
0x1
;
// 0b0001
static
constexpr
uint8_t
binary_max_subnorm
=
0x1
;
// 0b0001
static
constexpr
float
data_max_normal_number
=
6
;
static
constexpr
float
data_min_subnormal_number
=
0.5
;
__host__
__device__
static
constexpr
f4_t
Min
()
{
return
f4_t
(
binary_min_normal
);
}
__host__
__device__
static
constexpr
f4_t
Max
()
{
return
f4_t
(
binary_max_normal
);
}
__host__
__device__
static
constexpr
f4_t
Lowest
()
{
return
f4_t
(
binary_lowest_normal
);
}
__host__
__device__
static
constexpr
f4_t
MinSubnorm
()
{
return
f4_t
(
binary_min_subnorm
);
}
__host__
__device__
static
constexpr
f4_t
MaxSubnorm
()
{
return
f4_t
(
binary_max_subnorm
);
}
__host__
__device__
static
constexpr
float
DataMaxNorm
()
{
return
data_max_normal_number
;
}
__host__
__device__
static
constexpr
float
DataMinSubnorm
()
{
return
data_min_subnormal_number
;
}
};
template
<
>
struct
NumericLimits
<
e8m0_bexp_t
>
{
static
constexpr
e8m0_bexp_t
binary_min
=
0x00
;
// 0b00000000
static
constexpr
e8m0_bexp_t
binary_max
=
0xFE
;
// 0b11111110
static
constexpr
e8m0_bexp_t
binary_qnan
=
0xFF
;
// 0b11111111
static
constexpr
e8m0_bexp_t
binary_1
=
0x7F
;
// 0b01111111
static
constexpr
e8m0_bexp_t
binary_2
=
0x80
;
// 0b10000000
static
constexpr
e8m0_bexp_t
binary_3
=
0x82
;
// 0b10000010
static
constexpr
e8m0_bexp_t
binary_135
=
0x87
;
// 0b10000111
static
constexpr
e8m0_bexp_t
binary_142
=
0x8E
;
// 0b10001110
__host__
__device__
static
constexpr
e8m0_bexp_t
Min
()
{
return
e8m0_bexp_t
(
binary_min
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Max
()
{
return
e8m0_bexp_t
(
binary_max
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
QuietNaN
()
{
return
e8m0_bexp_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_1
()
{
return
e8m0_bexp_t
(
binary_1
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_2
()
{
return
e8m0_bexp_t
(
binary_2
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_3
()
{
return
e8m0_bexp_t
(
binary_3
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_135
()
{
return
e8m0_bexp_t
(
binary_135
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_142
()
{
return
e8m0_bexp_t
(
binary_142
);
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
NumericUtils
struct
NumericUtils
{
{
...
@@ -2028,6 +2123,7 @@ struct NumericUtils<float>
...
@@ -2028,6 +2123,7 @@ struct NumericUtils<float>
static
constexpr
uint32_t
NegInf
=
0xFF800000
;
static
constexpr
uint32_t
NegInf
=
0xFF800000
;
static
constexpr
uint32_t
NaN
=
0x7F800001
;
static
constexpr
uint32_t
NaN
=
0x7F800001
;
static
constexpr
uint32_t
Neg0
=
0x80000000
;
static
constexpr
uint32_t
Neg0
=
0x80000000
;
static
constexpr
bool
has_inf
=
true
;
using
bitwise_type
=
uint32_t
;
using
bitwise_type
=
uint32_t
;
};
};
...
@@ -2045,9 +2141,19 @@ struct NumericUtils<half_t>
...
@@ -2045,9 +2141,19 @@ struct NumericUtils<half_t>
static
constexpr
uint32_t
NegInf
=
0xFC00
;
static
constexpr
uint32_t
NegInf
=
0xFC00
;
static
constexpr
uint32_t
NaN
=
0x7C01
;
static
constexpr
uint32_t
NaN
=
0x7C01
;
static
constexpr
uint32_t
Neg0
=
0x8000
;
static
constexpr
uint32_t
Neg0
=
0x8000
;
static
constexpr
bool
has_inf
=
true
;
using
bitwise_type
=
uint16_t
;
using
bitwise_type
=
uint16_t
;
};
};
template
<
>
struct
NumericUtils
<
bhalf_t
>
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
7
;
static
constexpr
int
bias
=
128
;
// negative zero nan mode
// static constexpr int bias = 127; // ieee mode
};
template
<
>
template
<
>
struct
NumericUtils
<
f8_fnuz_t
>
struct
NumericUtils
<
f8_fnuz_t
>
{
{
...
@@ -2055,6 +2161,7 @@ struct NumericUtils<f8_fnuz_t>
...
@@ -2055,6 +2161,7 @@ struct NumericUtils<f8_fnuz_t>
static
constexpr
int
mant
=
3
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
8
;
// negative zero nan mode
static
constexpr
int
bias
=
8
;
// negative zero nan mode
// static constexpr int bias = 7; // ieee mode
// static constexpr int bias = 7; // ieee mode
static
constexpr
bool
has_inf
=
false
;
};
};
template
<
>
template
<
>
...
@@ -2064,6 +2171,7 @@ struct NumericUtils<bf8_fnuz_t>
...
@@ -2064,6 +2171,7 @@ struct NumericUtils<bf8_fnuz_t>
static
constexpr
int
mant
=
2
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
16
;
// negative zero nan mode
static
constexpr
int
bias
=
16
;
// negative zero nan mode
// static constexpr int bias = 15; // ieee mode
// static constexpr int bias = 15; // ieee mode
static
constexpr
bool
has_inf
=
false
;
};
};
template
<
>
template
<
>
struct
NumericUtils
<
f8_ocp_t
>
struct
NumericUtils
<
f8_ocp_t
>
...
@@ -2082,11 +2190,47 @@ struct NumericUtils<bf8_ocp_t>
...
@@ -2082,11 +2190,47 @@ struct NumericUtils<bf8_ocp_t>
};
};
template
<
>
template
<
>
struct
NumericUtils
<
bhalf_t
>
struct
NumericUtils
<
f4_t
>
{
static
constexpr
int
exp
=
2
;
static
constexpr
int
mant
=
1
;
static
constexpr
int
bias
=
1
;
static
constexpr
uint32_t
sr_shift
=
10
;
static
constexpr
int
unbiased_exp_min
=
0
;
static
constexpr
int
unbiased_exp_max
=
2
;
static
constexpr
int
biased_exp_min
=
1
;
static
constexpr
int
biased_exp_max
=
3
;
static
constexpr
uint8_t
positive_zero_mask
=
0b0000
;
static
constexpr
uint8_t
negative_zero_mask
=
0b1000
;
static
constexpr
uint8_t
one_mask
=
0b0010
;
static
constexpr
uint8_t
set_sign_mask
=
0b0111
;
static
constexpr
uint8_t
data_max_positive_normal_mask
=
0b0111
;
static
constexpr
uint8_t
data_max_negative_normal_mask
=
0b1111
;
static
constexpr
uint8_t
data_max_positive_subnormal_mask
=
0b0001
;
static
constexpr
uint8_t
data_max_negative_subnormal_mask
=
0b1001
;
static
constexpr
bool
has_inf
=
false
;
using
bitwise_type
=
uint8_t
;
};
template
<
>
struct
NumericUtils
<
e8m0_bexp_t
>
{
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
7
;
static
constexpr
int
mant
=
0
;
static
constexpr
int
bias
=
128
;
// negative zero nan mode
static
constexpr
int
bias
=
127
;
// static constexpr int bias = 127; // ieee mode
static
constexpr
int
unbiased_exp_min
=
-
127
;
static
constexpr
int
unbiased_exp_max
=
127
;
static
constexpr
int
biased_exp_min
=
0
;
static
constexpr
int
biased_exp_max
=
254
;
using
bitwise_type
=
uint8_t
;
};
};
}
// namespace ck
}
// namespace ck
include/ck/utility/e8m0_utils.hpp
0 → 100644
View file @
cde0f480
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace
ck
::
utils
{
__host__
__device__
inline
float
cast_to_float
(
e8m0_bexp_t
const
bexp
)
{
// TODO: check performance and try bit shift impl
return
std
::
powf
(
2
,
bit_cast
<
uint8_t
>
(
bexp
)
-
NumericUtils
<
e8m0_bexp_t
>::
bias
);
}
__host__
__device__
inline
e8m0_bexp_t
cast_from_float
(
float
const
scale
)
{
uint32_t
e
=
bit_cast
<
uint32_t
>
(
scale
)
&
NumericUtils
<
float
>::
nan_mask
;
return
static_cast
<
uint8_t
>
(
e
>>
23
);
}
template
<
>
__host__
__device__
inline
int
get_exponent_value
<
e8m0_bexp_t
>
(
e8m0_bexp_t
x
)
{
x
.
data
>>=
NumericUtils
<
e8m0_bexp_t
>::
mant
;
x
.
data
&=
((
1
<<
NumericUtils
<
e8m0_bexp_t
>::
exp
)
-
1
);
return
static_cast
<
int
>
(
x
.
data
);
}
}
// namespace ck::utils
include/ck/utility/mxf4_utils.hpp
0 → 100644
View file @
cde0f480
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace
ck
::
utils
{
template
<
>
__host__
__device__
inline
bool
is_nan
<
f4_t
>
(
e8m0_bexp_t
const
scale
,
f4_t
const
dataBytes
[[
maybe_unused
]])
{
// no need to check for data as it does not have NaN representation
return
scale
==
NumericLimits
<
e8m0_bexp_t
>::
QuietNaN
();
}
// no infinity representation in ocp_e2m1_mxfp4 will always return false
template
<
>
__host__
__device__
inline
bool
is_inf
<
f4_t
>
(
e8m0_bexp_t
const
scale
[[
maybe_unused
]],
f4_t
const
data
[[
maybe_unused
]])
{
// no inf representation for ocp_e2m1_mxfp4
return
false
;
}
template
<
>
__host__
__device__
inline
bool
is_zero
<
f4_t
>
(
e8m0_bexp_t
const
scale
,
f4_t
const
data
)
{
if
(
is_nan
<
f4_t
>
(
scale
,
data
))
return
false
;
// no need to check for scale as it does not have a 0 representation
f4_t
result
=
(
data
&
0b00001111
)
&
NumericUtils
<
f4_t
>::
set_sign_mask
;
return
result
==
0b0
;
}
template
<
>
__host__
__device__
inline
float
to_float
<
f4_t
>
(
e8m0_bexp_t
const
scale
,
f4_t
const
data
)
{
if
(
is_nan
<
f4_t
>
(
scale
,
data
))
return
std
::
numeric_limits
<
float
>::
quiet_NaN
();
if
(
is_zero
<
f4_t
>
(
scale
,
data
))
return
0.0
f
;
f4_t
prepared_data
=
data
&
0b00001111
;
int
scale_exp
=
get_exponent_value
<
e8m0_bexp_t
>
(
scale
);
return
convert_to_float
<
f4_t
>
(
prepared_data
,
scale_exp
);
}
template
<
>
__host__
__device__
inline
f4_t
sat_convert_to_type
<
f4_t
>
(
float
value
)
{
cvt
t
;
t
.
value_float
=
value
;
uint32_t
sign
=
t
.
value_bitwise
>>
31
;
if
(
std
::
isnan
(
value
))
{
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
}
if
(
std
::
abs
(
value
)
>
NumericLimits
<
f4_t
>::
Max
())
// covers inf case as well
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
f4_t
res
=
convert_to_type
<
f4_t
>
(
value
);
if
(
std
::
abs
(
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
res
))
<
NumericLimits
<
f4_t
>::
DataMinSubnorm
())
return
value
<
0
?
NumericUtils
<
f4_t
>::
negative_zero_mask
:
NumericUtils
<
f4_t
>::
positive_zero_mask
;
return
res
;
}
template
<
>
__host__
__device__
inline
f4_t
sat_convert_to_type_sr
<
f4_t
>
(
float
value
,
uint32_t
seed
)
{
cvt
t
;
t
.
value_float
=
value
;
uint32_t
sign
=
t
.
value_bitwise
>>
31
;
if
(
std
::
isnan
(
value
))
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
if
(
std
::
abs
(
value
)
>
NumericLimits
<
f4_t
>::
Max
())
// covers inf case as well
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
f4_t
res
=
convert_to_type_sr
<
f4_t
>
(
value
,
seed
);
if
(
std
::
abs
(
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
res
))
<
NumericLimits
<
f4_t
>::
DataMinSubnorm
())
return
value
<
0
?
NumericUtils
<
f4_t
>::
negative_zero_mask
:
NumericUtils
<
f4_t
>::
positive_zero_mask
;
return
res
;
}
}
// namespace ck::utils
include/ck/utility/mxfp_utils.hpp
0 → 100644
View file @
cde0f480
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
::
utils
{
union
cvt
{
float
value_float
;
uint32_t
value_bitwise
;
};
template
<
typename
DTYPE
>
inline
bool
getDataHasInf
()
{
return
DTYPE
::
dataInfo
.
hasInf
;
}
template
<
typename
T
>
__host__
__device__
inline
bool
is_zero
(
e8m0_bexp_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
inline
bool
is_nan
(
e8m0_bexp_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
inline
bool
is_inf
(
e8m0_bexp_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
inline
int
get_exponent_value
(
T
x
)
{
x
>>=
NumericUtils
<
T
>::
mant
;
x
&=
((
1
<<
NumericUtils
<
T
>::
exp
)
-
1
);
return
static_cast
<
int
>
(
x
);
}
template
<
typename
T
>
__host__
__device__
inline
bool
is_subnormal
(
T
x
)
{
return
get_exponent_value
<
T
>
(
x
)
==
0
;
}
template
<
typename
T
>
__host__
__device__
inline
double
get_mantissa_value
(
T
x
)
{
double
mantissa
=
is_subnormal
<
T
>
(
x
)
?
0.0
f
:
1.0
f
;
for
(
uint
i
=
0
;
i
<
NumericUtils
<
T
>::
mant
;
i
++
)
{
mantissa
+=
std
::
pow
(
2
,
-
int32_t
((
NumericUtils
<
T
>::
mant
-
i
)))
*
(
x
&
0b1
);
x
>>=
1
;
}
return
mantissa
;
}
template
<
typename
T
>
__host__
__device__
inline
bool
get_data_has_inf
()
{
return
NumericUtils
<
T
>::
has_inf
;
}
template
<
typename
T
>
__host__
__device__
float
convert_to_float
(
T
data
,
int
scale_exp
)
{
float
d_sign
=
std
::
pow
(
-
1
,
static_cast
<
float
>
(
data
>>
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
)));
float
d_exp
;
if
(
is_subnormal
<
T
>
(
data
))
d_exp
=
std
::
pow
(
2
,
1
-
static_cast
<
int
>
(
NumericUtils
<
T
>::
bias
));
else
d_exp
=
std
::
pow
(
2
,
get_exponent_value
<
T
>
(
data
)
-
static_cast
<
int
>
(
NumericUtils
<
T
>::
bias
));
float
d_mant
=
get_mantissa_value
<
T
>
(
data
);
float
data_value
=
d_sign
*
d_exp
*
d_mant
;
float
scale_value
=
std
::
pow
(
2
,
static_cast
<
float
>
((
scale_exp
-
static_cast
<
int
>
(
NumericUtils
<
e8m0_bexp_t
>::
bias
))));
return
data_value
*
scale_value
;
}
template
<
typename
T
>
__host__
__device__
inline
float
to_float
(
e8m0_bexp_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
T
sat_convert_to_type
(
float
value
);
template
<
typename
T
>
__host__
__device__
T
sat_convert_to_type_sr
(
float
value
,
uint32_t
seed
);
template
<
typename
T
>
inline
T
convert_to_type
(
float
value
)
{
using
bitwise_type
=
typename
NumericUtils
<
T
>::
bitwise_type
;
if
(
std
::
abs
(
value
)
>
NumericLimits
<
T
>::
Max
())
{
float
max_value
=
NumericLimits
<
T
>::
Max
();
cvt
t
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
max_value
;
uint32_t
max_bitwise
=
t
.
value_bitwise
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
value
;
bitwise_type
sign
=
t
.
value_bitwise
>>
(
NumericUtils
<
float
>::
exp
+
NumericUtils
<
float
>::
mant
);
bitwise_type
exp
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
)
-
(
NumericUtils
<
float
>::
bias
-
NumericUtils
<
T
>::
bias
);
bitwise_type
mantissa
=
max_bitwise
>>
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
uint32_t
mant_prev
=
max_bitwise
>>
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
mant_prev
&=
((
1
<<
NumericUtils
<
T
>::
mant
)
-
1
);
mant_prev
--
;
mant_prev
<<=
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
uint32_t
prev_bit
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
<<
NumericUtils
<
float
>::
mant
)
|
mant_prev
;
t
.
value_bitwise
=
prev_bit
;
float
prev_val
=
t
.
value_float
;
float
diff
=
max_value
-
prev_val
;
float
actual_max
=
max_value
+
(
diff
/
2
);
if
(
std
::
abs
(
value
)
<
actual_max
)
{
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
|
(
exp
<<
NumericUtils
<
T
>::
mant
)
|
mantissa
;
}
else
{
if
(
!
get_data_has_inf
<
T
>
())
{
return
(
1
<<
(
NumericUtils
<
T
>::
mant
+
NumericUtils
<
T
>::
exp
))
-
1
;
}
else
{
exp
++
;
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
|
(
exp
<<
NumericUtils
<
T
>::
mant
);
}
}
}
const
int
mfmt
=
NumericUtils
<
float
>::
mant
;
uint32_t
x
;
x
=
bit_cast
<
uint32_t
>
(
value
);
uint32_t
head
,
mantissa
;
int32_t
exponent
,
bias
;
uint32_t
sign
;
head
=
x
&
NumericUtils
<
float
>::
head_mask
;
mantissa
=
x
&
NumericUtils
<
float
>::
mant_mask
;
exponent
=
(
head
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
;
sign
=
head
>>
(
NumericUtils
<
float
>::
mant
+
NumericUtils
<
float
>::
exp
);
bias
=
NumericUtils
<
float
>::
bias
;
if
(
x
==
0
)
{
return
0b0
;
}
const
int
mini_bias
=
NumericUtils
<
T
>::
bias
;
const
int
mini_denormal_act_exponent
=
1
-
mini_bias
;
int
act_exponent
,
out_exponent
,
exponent_diff
;
bool
is_subnorm
=
false
;
if
(
exponent
==
0
)
{
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
mini_denormal_act_exponent
-
act_exponent
;
is_subnorm
=
true
;
}
else
{
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
mini_denormal_act_exponent
)
{
exponent_diff
=
mini_denormal_act_exponent
-
act_exponent
;
is_subnorm
=
true
;
}
else
{
exponent_diff
=
0
;
}
mantissa
+=
(
1UL
<<
mfmt
);
}
auto
shift_amount
=
(
mfmt
-
NumericUtils
<
T
>::
mant
+
exponent_diff
);
shift_amount
=
(
shift_amount
>=
64
)
?
63
:
shift_amount
;
bool
midpoint
=
(
mantissa
&
((
1UL
<<
shift_amount
)
-
1
))
==
(
1UL
<<
(
shift_amount
-
1
));
float
min_subnorm
=
NumericLimits
<
T
>::
DataMinSubnorm
()
*
(
sign
?
-
1
:
1
);
if
(
is_subnorm
&&
std
::
abs
(
value
)
<
std
::
abs
(
min_subnorm
))
{
// closer to 0
if
(
std
::
abs
(
value
)
<=
std
::
abs
(
min_subnorm
-
value
))
return
0
;
else
return
1
|
(
sign
<<
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
));
}
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
out_exponent
=
(
act_exponent
+
exponent_diff
)
+
mini_bias
-
(
implicit_one
?
0
:
1
);
uint32_t
drop_mask
=
(
1UL
<<
(
mfmt
-
NumericUtils
<
T
>::
mant
))
-
1
;
bool
odd
=
mantissa
&
(
1UL
<<
(
mfmt
-
NumericUtils
<
T
>::
mant
));
mantissa
+=
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
)
&
drop_mask
;
if
(
out_exponent
==
0
)
{
if
((
1UL
<<
mfmt
)
&
mantissa
)
{
out_exponent
=
1
;
}
}
else
{
if
((
1UL
<<
(
mfmt
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
out_exponent
++
;
}
}
mantissa
>>=
(
mfmt
-
NumericUtils
<
T
>::
mant
);
if
(
out_exponent
==
0
&&
mantissa
==
0
)
{
return
0
;
}
mantissa
&=
(
1UL
<<
NumericUtils
<
T
>::
mant
)
-
1
;
return
(
sign
<<
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
|
(
out_exponent
<<
NumericUtils
<
T
>::
mant
)
|
mantissa
;
}
template
<
typename
T
>
inline
T
convert_to_type_sr
(
float
value
,
uint32_t
seed
)
{
if
(
std
::
abs
(
value
)
>
NumericLimits
<
T
>::
Max
())
{
float
max_value
=
NumericLimits
<
T
>::
Max
();
cvt
t
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
max_value
;
uint
max_bitwise
=
t
.
value_bitwise
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
value
;
T
sign
=
t
.
value_bitwise
>>
(
NumericUtils
<
float
>::
exp
+
NumericUtils
<
float
>::
mant
);
T
exp
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
)
-
(
NumericUtils
<
float
>::
bias
-
NumericUtils
<
T
>::
bias
);
uint32_t
mant_prev
=
max_bitwise
>>
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
mant_prev
&=
((
1UL
<<
NumericUtils
<
T
>::
mant
)
-
1
);
mant_prev
--
;
mant_prev
<<=
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
uint32_t
prev_bit
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
<<
NumericUtils
<
float
>::
mant
)
|
mant_prev
;
t
.
value_bitwise
=
prev_bit
;
float
prev_val
=
t
.
value_float
;
float
diff
=
max_value
-
prev_val
;
float
actual_max
=
max_value
+
(
diff
/
2
);
if
(
std
::
abs
(
value
)
<
actual_max
)
{
double
d_max_value
=
static_cast
<
double
>
(
max_value
);
double
d_actual_max
=
static_cast
<
double
>
(
actual_max
);
double
d_value
=
static_cast
<
double
>
(
value
);
double
d_is
=
std
::
abs
(
d_max_value
-
d_actual_max
);
double
d_seed
=
static_cast
<
double
>
(
seed
);
double
d_prob
=
1.0
f
-
(
std
::
abs
(
d_value
-
d_max_value
)
/
d_is
);
// prob to round down
double
thresh
=
UINT_MAX
*
d_prob
;
if
(
!
get_data_has_inf
<
T
>
()
||
d_seed
<=
thresh
)
// return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
return
sign
==
0
?
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
;
else
{
exp
++
;
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
// inf
|
(
exp
<<
NumericUtils
<
T
>::
mant
);
}
}
else
{
if
(
!
get_data_has_inf
<
T
>
())
return
(
1
<<
(
NumericUtils
<
T
>::
mant
+
NumericUtils
<
T
>::
exp
))
-
1
;
else
{
exp
++
;
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
// inf
|
(
exp
<<
NumericUtils
<
T
>::
mant
);
}
}
}
uint32_t
f32
=
bit_cast
<
uint32_t
>
(
value
);
auto
f32_mant
=
f32
&
NumericUtils
<
float
>::
mant_mask
;
auto
head
=
f32
&
NumericUtils
<
float
>::
head_mask
;
auto
f32_exp
=
(
head
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
;
auto
sign_bit
=
head
>>
(
NumericUtils
<
float
>::
mant
+
NumericUtils
<
float
>::
exp
);
auto
sign
=
sign_bit
<<
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
);
f32_exp
=
static_cast
<
int32_t
>
(
f32_exp
)
-
NumericUtils
<
float
>::
bias
;
int32_t
exp
=
f32_exp
;
auto
mant
=
f32_mant
;
bool
subnorm
=
false
;
if
(
f32
==
0
)
return
0b0
;
if
(
exp
>=
NumericUtils
<
T
>::
unbiased_exp_min
)
{
mant
=
f32_mant
;
}
// if the exponent bit is 8, then the subnormal is exactly the same as f32
else
if
(
exp
<
NumericUtils
<
T
>::
unbiased_exp_min
&&
NumericUtils
<
T
>::
exp
<
NumericUtils
<
float
>::
exp
)
{
subnorm
=
true
;
auto
diff
=
static_cast
<
uint32_t
>
(
NumericUtils
<
T
>::
unbiased_exp_min
-
exp
);
if
(
diff
>=
32
)
{
mant
=
0
;
f32_mant
=
0
;
}
else
{
f32_mant
|=
static_cast
<
uint32_t
>
(
1
)
<<
NumericUtils
<
float
>::
mant
;
f32_mant
>>=
diff
;
}
exp
=
0
;
mant
=
f32_mant
;
}
uint32_t
sr_shift
=
NumericUtils
<
T
>::
sr_shift
;
// For stochastic-rounding we add the aligned random value to the
// mantissa and then truncate (RTZ).
mant
+=
seed
>>
sr_shift
;
// Increment exponent when mantissa overflows due to rounding
if
(
mant
>=
static_cast
<
uint32_t
>
(
1
)
<<
NumericUtils
<
float
>::
mant
)
++
exp
;
mant
>>=
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
mant
&=
((
1
<<
NumericUtils
<
T
>::
mant
)
-
1
);
auto
biased_exp
=
static_cast
<
uint32_t
>
(
exp
);
if
(
!
subnorm
)
biased_exp
=
static_cast
<
uint32_t
>
(
exp
+
NumericUtils
<
T
>::
bias
);
biased_exp
&=
((
1
<<
NumericUtils
<
T
>::
exp
)
-
1
);
auto
val
=
sign
|
biased_exp
<<
NumericUtils
<
T
>::
mant
|
mant
;
return
val
;
}
}
// namespace ck::utils
include/ck/utility/type_convert.hpp
View file @
cde0f480
...
@@ -4,7 +4,9 @@
...
@@ -4,7 +4,9 @@
#pragma once
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/e8m0_utils.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/array.hpp"
...
@@ -583,6 +585,1053 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
...
@@ -583,6 +585,1053 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
#endif
#endif
}
}
// convert fp32 to fp4 with rounding to nearest even
inline
__host__
__device__
f4_t
f4_convert_rne
(
float
x
,
float
scale
=
1.0
f
)
{
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
f4_t
f4_array
[
4
];
}
value
{
0
};
value
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
value
.
bitwise
,
x
,
x
,
scale
,
0
);
return
value
.
f4_array
[
0
];
#else
return
utils
::
sat_convert_to_type
<
f4_t
>
(
x
/
scale
);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4 with rne
inline
__host__
__device__
f4x2_t
f4_convert_rne
(
float2_t
x
,
float
scale
=
1.0
f
)
{
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{
0
};
value
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
value
.
bitwise
,
x
[
0
],
x
[
1
],
scale
,
0
);
return
value
.
f4x2_array
[
0
];
#else
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{
0
};
uint8_t
l
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
1
]
/
scale
);
uint8_t
h
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
0
]
/
scale
);
value
.
bitwise
=
(
h
<<
4
)
|
l
;
return
value
.
f4x2_array
[
0
];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with rne
inline
__host__
__device__
f4x32_t
f4_convert_rne
(
float32_t
x
,
float
scale
=
1.0
f
)
{
#if defined(__gfx950__)
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{},
tmp_values
{};
// TODO: pack in a loop
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
0
],
x
[
1
],
scale
,
0
);
f4_values
.
f4x2_array
[
0
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
2
],
x
[
3
],
scale
,
0
);
f4_values
.
f4x2_array
[
1
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
4
],
x
[
5
],
scale
,
0
);
f4_values
.
f4x2_array
[
2
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
6
],
x
[
7
],
scale
,
0
);
f4_values
.
f4x2_array
[
3
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
8
],
x
[
9
],
scale
,
0
);
f4_values
.
f4x2_array
[
4
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
10
],
x
[
11
],
scale
,
0
);
f4_values
.
f4x2_array
[
5
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
12
],
x
[
13
],
scale
,
0
);
f4_values
.
f4x2_array
[
6
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
14
],
x
[
15
],
scale
,
0
);
f4_values
.
f4x2_array
[
7
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
16
],
x
[
17
],
scale
,
0
);
f4_values
.
f4x2_array
[
8
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
18
],
x
[
19
],
scale
,
0
);
f4_values
.
f4x2_array
[
9
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
20
],
x
[
21
],
scale
,
0
);
f4_values
.
f4x2_array
[
10
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
22
],
x
[
23
],
scale
,
0
);
f4_values
.
f4x2_array
[
11
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
24
],
x
[
25
],
scale
,
0
);
f4_values
.
f4x2_array
[
12
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
26
],
x
[
27
],
scale
,
0
);
f4_values
.
f4x2_array
[
13
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
28
],
x
[
29
],
scale
,
0
);
f4_values
.
f4x2_array
[
14
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
30
],
x
[
31
],
scale
,
0
);
f4_values
.
f4x2_array
[
15
]
=
tmp_values
.
f4x2_array
[
0
];
return
f4_values
.
f4x32_array
;
#else
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{};
// TODO: pack in a loop
auto
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
0
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
1
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
2
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
3
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
4
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
5
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
6
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
7
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
8
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
9
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
10
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
11
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
12
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
13
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
14
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
15
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
16
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
17
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
18
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
19
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
20
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
21
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
22
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
23
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
24
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
25
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
26
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
27
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
28
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
29
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
30
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
31
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
return
f4_values
.
f4x32_array
;
#endif
}
// convert fp32 to fp4 with stochastic rounding
inline
__host__
__device__
f4_t
f4_convert_sr
(
float
x
,
float
scale
=
1.0
f
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
f4_t
f4_array
[
4
];
}
value
{
0
};
union
{
float
float_array
[
2
];
float2_t
float2_array
;
}
float_values
{{
x
}};
value
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
float2_array
,
rng
,
scale
,
0
);
return
value
.
f4_array
[
0
];
#else
return
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
/
scale
,
rng
);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline
__host__
__device__
f4x2_t
f4_convert_sr
(
float2_t
x
,
float
scale
=
1.0
f
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
[
0
]);
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{
0
};
value
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
x
,
rng
,
scale
,
0
);
return
value
.
f4x2_array
[
0
];
#else
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{
0
};
uint8_t
l
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
1
]
/
scale
,
rng
);
uint8_t
h
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
0
]
/
scale
,
rng
);
value
.
bitwise
=
(
h
<<
4
)
|
l
;
return
value
.
f4x2_array
[
0
];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline
__host__
__device__
f4x32_t
f4_convert_sr
(
float32_t
x
,
float
scale
=
1.0
f
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
[
0
]);
#if defined(__gfx950__)
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{
0
},
tmp_values
{
0
};
union
{
float2_t
floatx2_array
[
16
];
float32_t
floatx32_array
;
}
float_values
{{
0
}};
// TODO: pack in a loop
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
0
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
0
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
1
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
1
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
2
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
2
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
3
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
3
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
4
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
4
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
5
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
5
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
6
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
6
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
7
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
7
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
8
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
8
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
9
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
9
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
10
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
10
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
11
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
11
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
12
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
12
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
13
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
13
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
14
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
14
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
tmp_values
.
bitwise
,
float_values
.
floatx2_array
[
15
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
15
]
=
tmp_values
.
f4x2_array
[
0
];
return
f4_values
.
f4x32_array
;
#else
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{
0
};
// TODO: pack in a loop
auto
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
0
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
1
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
2
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
3
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
4
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
5
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
6
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
7
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
8
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
9
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
10
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
11
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
12
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
13
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
14
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
15
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
16
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
17
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
18
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
19
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
20
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
21
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
22
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
23
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
24
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
25
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
26
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
27
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
28
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
29
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
30
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
31
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
return
f4_values
.
f4x32_array
;
#endif
}
// convert fp32 to fp4
template
<
>
inline
__host__
__device__
f4_t
type_convert
<
f4_t
,
float
>
(
float
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
);
#else
return
f4_convert_rne
(
x
);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template
<
>
inline
__host__
__device__
f4x2_t
type_convert
<
f4x2_t
,
float2_t
>
(
float2_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
);
#else
return
f4_convert_rne
(
x
);
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template
<
>
inline
__host__
__device__
f4x32_t
type_convert
<
f4x32_t
,
float32_t
>
(
float32_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
);
#else
return
f4_convert_rne
(
x
);
#endif
}
// convert fp4 to fp32
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
f4_t
>
(
f4_t
x
)
{
#if defined(__gfx950__)
union
{
float
float_array
[
2
];
float2_t
float2_array
;
}
float_values
{};
float
scale
=
1.0
f
;
float_values
.
float2_array
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
scale
,
0
);
return
float_values
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template
<
>
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f4x2_t
>
(
f4x2_t
x
)
{
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{};
value
.
f4x2_array
[
0
]
=
x
;
float
scale
=
1.0
f
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
scale
,
0
);
#else
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
()),
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
())};
return
ret
;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template
<
>
inline
__host__
__device__
float32_t
type_convert
<
float32_t
,
f4x32_t
>
(
f4x32_t
x
)
{
#if defined(__gfx950__)
union
{
f4x32_t
f4x32_array
;
f4x2_t
fp4x2
[
16
];
}
value
{
x
};
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
bitwise_value
{};
float2_t
op
;
float32_t
ret
;
float
scale
=
1.0
f
;
// TODO: pack in a loop
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
0
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
0
]
=
op
[
0
];
ret
[
1
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
2
]
=
op
[
0
];
ret
[
3
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
2
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
4
]
=
op
[
0
];
ret
[
5
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
3
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
6
]
=
op
[
0
];
ret
[
7
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
4
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
8
]
=
op
[
0
];
ret
[
9
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
5
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
10
]
=
op
[
0
];
ret
[
11
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
6
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
12
]
=
op
[
0
];
ret
[
13
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
7
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
14
]
=
op
[
0
];
ret
[
15
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
8
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
16
]
=
op
[
0
];
ret
[
17
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
9
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
18
]
=
op
[
0
];
ret
[
19
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
10
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
20
]
=
op
[
0
];
ret
[
21
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
11
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
22
]
=
op
[
0
];
ret
[
23
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
12
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
24
]
=
op
[
0
];
ret
[
25
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
13
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
26
]
=
op
[
0
];
ret
[
27
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
14
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
28
]
=
op
[
0
];
ret
[
29
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
15
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
30
]
=
op
[
0
];
ret
[
31
]
=
op
[
1
];
return
ret
;
#else
union
{
float32_t
float32_array
;
float
float_array
[
32
];
}
float_values
{};
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{
bit_cast
<
__uint128_t
>
(
x
)};
// TODO: pack in a loop
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
return
float_values
.
float32_array
;
#endif
}
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
e8m0_bexp_t
>
(
e8m0_bexp_t
scale
)
{
return
utils
::
cast_to_float
(
scale
);
}
template
<
>
inline
__host__
__device__
e8m0_bexp_t
type_convert
<
e8m0_bexp_t
,
float
>
(
float
scale
)
{
return
utils
::
cast_from_float
(
scale
);
}
// Declare a template function for scaled conversion
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
// convert fp4 to fp32
template
<
>
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_bexp_t
scale
,
f4_t
x
)
{
#if defined(__gfx950__)
union
{
float
float_array
[
2
];
float2_t
float2_array
;
}
float_values
{};
float_values
.
float2_array
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
type_convert
<
float
>
(
scale
),
0
);
return
float_values
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f4_t
>
(
scale
,
x
);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template
<
>
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_bexp_t
scale
,
f4x2_t
x
)
{
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{};
value
.
f4x2_array
[
0
]
=
x
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
#else
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
()),
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
())};
return
ret
;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template
<
>
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f4x32_t
>
(
e8m0_bexp_t
scale
,
f4x32_t
x
)
{
#if defined(__gfx950__)
union
{
f4x32_t
f4x32_array
;
f4x2_t
fp4x2
[
16
];
}
value
{
x
};
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
bitwise_value
{};
float2_t
op
;
float32_t
ret
;
// TODO: pack in a loop
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
0
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
0
]
=
op
[
0
];
ret
[
1
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
2
]
=
op
[
0
];
ret
[
3
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
2
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
4
]
=
op
[
0
];
ret
[
5
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
3
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
6
]
=
op
[
0
];
ret
[
7
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
4
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
8
]
=
op
[
0
];
ret
[
9
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
5
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
10
]
=
op
[
0
];
ret
[
11
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
6
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
12
]
=
op
[
0
];
ret
[
13
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
7
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
14
]
=
op
[
0
];
ret
[
15
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
8
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
16
]
=
op
[
0
];
ret
[
17
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
9
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
18
]
=
op
[
0
];
ret
[
19
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
10
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
20
]
=
op
[
0
];
ret
[
21
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
11
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
22
]
=
op
[
0
];
ret
[
23
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
12
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
24
]
=
op
[
0
];
ret
[
25
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
13
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
26
]
=
op
[
0
];
ret
[
27
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
14
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
28
]
=
op
[
0
];
ret
[
29
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
15
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
30
]
=
op
[
0
];
ret
[
31
]
=
op
[
1
];
return
ret
;
#else
union
{
float32_t
float32_array
;
float
float_array
[
32
];
}
float_values
{};
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{
bit_cast
<
__uint128_t
>
(
x
)};
// TODO: pack in a loop
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
return
float_values
.
float32_array
;
#endif
}
// convert fp32 to fp4
template
<
>
inline
__host__
__device__
f4_t
scaled_type_convert
<
f4_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template
<
>
inline
__host__
__device__
f4x2_t
scaled_type_convert
<
f4x2_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template
<
>
inline
__host__
__device__
f4x32_t
scaled_type_convert
<
f4x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
template
<
typename
Y
,
typename
X
,
std
::
size_t
NumElems
>
template
<
typename
Y
,
typename
X
,
std
::
size_t
NumElems
>
inline
__host__
__device__
void
array_convert
(
std
::
array
<
Y
,
NumElems
>&
y
,
inline
__host__
__device__
void
array_convert
(
std
::
array
<
Y
,
NumElems
>&
y
,
const
std
::
array
<
X
,
NumElems
>&
x
)
const
std
::
array
<
X
,
NumElems
>&
x
)
...
...
library/include/ck/library/utility/check_err.hpp
View file @
cde0f480
...
@@ -26,6 +26,7 @@ namespace utils {
...
@@ -26,6 +26,7 @@ namespace utils {
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_relative_threshold
(
const
int
number_of_accumulations
=
1
)
double
get_relative_threshold
(
const
int
number_of_accumulations
=
1
)
{
{
using
F4
=
ck
::
f4_t
;
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
...
@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
...
@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using
I8
=
int8_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
using
I32
=
int32_t
;
static_assert
(
is_same_v
<
ComputeDataType
,
F
8
>
||
is_same_v
<
ComputeDataType
,
F
16
>
||
static_assert
(
is_same_v
<
ComputeDataType
,
F
4
>
||
is_same_v
<
ComputeDataType
,
F
8
>
||
is_same_v
<
ComputeDataType
,
B
F16
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
F16
>
||
is_same_v
<
ComputeDataType
,
BF16
>
||
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I
32
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
I
8
>
||
is_same_v
<
ComputeDataType
,
int
>
,
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
double
compute_error
=
0
;
double
compute_error
=
0
;
if
constexpr
(
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
if
constexpr
(
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
...
@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
...
@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error
=
std
::
pow
(
2
,
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
compute_error
=
std
::
pow
(
2
,
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
}
}
static_assert
(
is_same_v
<
OutDataType
,
F
8
>
||
is_same_v
<
OutDataType
,
F
16
>
||
static_assert
(
is_same_v
<
OutDataType
,
F
4
>
||
is_same_v
<
OutDataType
,
F
8
>
||
is_same_v
<
OutDataType
,
B
F16
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
F16
>
||
is_same_v
<
OutDataType
,
BF16
>
||
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I
32
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
I
8
>
||
is_same_v
<
OutDataType
,
int
>
,
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
double
output_error
=
0
;
double
output_error
=
0
;
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
...
@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
...
@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_same_v
<
AccDataType
,
F
8
>
||
is_same_v
<
AccDataType
,
F
16
>
||
static_assert
(
is_same_v
<
AccDataType
,
F
4
>
||
is_same_v
<
AccDataType
,
F
8
>
||
is_same_v
<
AccDataType
,
B
F16
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
F16
>
||
is_same_v
<
AccDataType
,
BF16
>
||
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I
32
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
I
8
>
||
is_same_v
<
AccDataType
,
int
>
,
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
double
acc_error
=
0
;
double
acc_error
=
0
;
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
...
@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
...
@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number_of_accumulations
=
1
)
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number_of_accumulations
=
1
)
{
{
using
F4
=
ck
::
f4_t
;
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
...
@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
...
@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using
I8
=
int8_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
using
I32
=
int32_t
;
static_assert
(
is_same_v
<
ComputeDataType
,
F
8
>
||
is_same_v
<
ComputeDataType
,
F
16
>
||
static_assert
(
is_same_v
<
ComputeDataType
,
F
4
>
||
is_same_v
<
ComputeDataType
,
F
8
>
||
is_same_v
<
ComputeDataType
,
B
F16
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
F16
>
||
is_same_v
<
ComputeDataType
,
BF16
>
||
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I
32
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
I
8
>
||
is_same_v
<
ComputeDataType
,
int
>
,
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
double
compute_error
=
0
;
double
compute_error
=
0
;
...
@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
...
@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
compute_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
}
}
static_assert
(
is_same_v
<
OutDataType
,
F
8
>
||
is_same_v
<
OutDataType
,
F
16
>
||
static_assert
(
is_same_v
<
OutDataType
,
F
4
>
||
is_same_v
<
OutDataType
,
F
8
>
||
is_same_v
<
OutDataType
,
B
F16
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
F16
>
||
is_same_v
<
OutDataType
,
BF16
>
||
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I
32
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
I
8
>
||
is_same_v
<
OutDataType
,
int
>
,
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
double
output_error
=
0
;
double
output_error
=
0
;
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
...
@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
...
@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_same_v
<
AccDataType
,
F
8
>
||
is_same_v
<
AccDataType
,
F
16
>
||
static_assert
(
is_same_v
<
AccDataType
,
F
4
>
||
is_same_v
<
AccDataType
,
F
8
>
||
is_same_v
<
AccDataType
,
B
F16
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
F16
>
||
is_same_v
<
AccDataType
,
BF16
>
||
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I
32
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
I
8
>
||
is_same_v
<
AccDataType
,
int
>
,
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
double
acc_error
=
0
;
double
acc_error
=
0
;
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
...
@@ -450,5 +452,54 @@ check_err(const Range& out,
...
@@ -450,5 +452,54 @@ check_err(const Range& out,
return
res
;
return
res
;
}
}
template
<
typename
Range
,
typename
RefRange
>
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f4_t
>
),
bool
>
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
double
rtol
=
0.5
,
double
atol
=
0.5
)
{
if
(
out
.
size
()
!=
ref
.
size
())
{
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
;
return
false
;
}
bool
res
{
true
};
int
err_count
=
0
;
double
err
=
0
;
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
err
=
std
::
abs
(
o
-
r
);
if
(
err
>
atol
+
rtol
*
std
::
abs
(
r
)
||
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
))
{
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
if
(
err_count
<
5
)
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
}
res
=
false
;
}
}
if
(
!
res
)
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
" number of errors: "
<<
err_count
<<
std
::
endl
;
}
return
res
;
}
}
// namespace utils
}
// namespace utils
}
// namespace ck
}
// namespace ck
library/include/ck/library/utility/host_tensor_generator.hpp
View file @
cde0f480
...
@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t>
...
@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t>
};
};
#endif
#endif
template
<
>
struct
GeneratorTensor_1
<
ck
::
f4_t
>
{
float
value
=
1.0
;
template
<
typename
...
Is
>
ck
::
f4_t
operator
()(
Is
...)
{
return
ck
::
type_convert
<
ck
::
f4_t
>
(
value
);
}
};
template
<
>
template
<
>
struct
GeneratorTensor_1
<
int8_t
>
struct
GeneratorTensor_1
<
int8_t
>
{
{
...
@@ -153,6 +165,20 @@ struct GeneratorTensor_2<ck::bf8_t>
...
@@ -153,6 +165,20 @@ struct GeneratorTensor_2<ck::bf8_t>
};
};
#endif
#endif
template
<
>
struct
GeneratorTensor_2
<
ck
::
f4_t
>
{
int
min_value
=
0
;
int
max_value
=
1
;
template
<
typename
...
Is
>
ck
::
f4_t
operator
()(
Is
...)
{
float
tmp
=
(
std
::
rand
()
%
(
max_value
-
min_value
))
+
min_value
;
return
ck
::
type_convert
<
ck
::
f4_t
>
(
tmp
);
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
GeneratorTensor_3
struct
GeneratorTensor_3
{
{
...
@@ -223,6 +249,23 @@ struct GeneratorTensor_3<ck::bf8_t>
...
@@ -223,6 +249,23 @@ struct GeneratorTensor_3<ck::bf8_t>
};
};
#endif
#endif
template
<
>
struct
GeneratorTensor_3
<
ck
::
f4_t
>
{
float
min_value
=
0
;
float
max_value
=
1
;
template
<
typename
...
Is
>
ck
::
f4_t
operator
()(
Is
...)
{
float
tmp
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
float
fp32_tmp
=
min_value
+
tmp
*
(
max_value
-
min_value
);
return
ck
::
type_convert
<
ck
::
f4_t
>
(
fp32_tmp
);
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
GeneratorTensor_4
struct
GeneratorTensor_4
{
{
...
...
test/data_type/CMakeLists.txt
View file @
cde0f480
...
@@ -42,6 +42,10 @@ if (CK_USE_FNUZ_FP8)
...
@@ -42,6 +42,10 @@ if (CK_USE_FNUZ_FP8)
add_dependencies
(
test_fp8 test_fp8_fnuz
)
add_dependencies
(
test_fp8 test_fp8_fnuz
)
add_dependencies
(
test_fp8 test_bf8_fnuz
)
add_dependencies
(
test_fp8 test_bf8_fnuz
)
endif
()
endif
()
add_gtest_executable
(
test_fp4 test_fp4.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp4 PRIVATE utility
)
endif
()
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
...
...
test/data_type/test_fp4.cpp
0 → 100644
View file @
cde0f480
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
e8m0_bexp_t
;
using
ck
::
f4_convert_rne
;
using
ck
::
f4_convert_sr
;
using
ck
::
f4_t
;
using
ck
::
f4x2_pk_t
;
using
ck
::
Number
;
using
ck
::
scaled_type_convert
;
using
ck
::
type_convert
;
using
ck
::
vector_type
;
using
ck
::
utils
::
cast_from_float
;
using
ck
::
utils
::
cast_to_float
;
TEST
(
FP4
,
NumericLimits
)
{
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
Min
(),
f4_t
{
0x2
});
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
Max
(),
f4_t
{
0x7
});
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
Lowest
(),
f4_t
{
0xF
});
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
MinSubnorm
(),
f4_t
{
0x1
});
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
MaxSubnorm
(),
f4_t
{
0x1
});
}
TEST
(
FP4
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// set maximum fp4 value
float
max_fp4
=
6.0
f
;
// convert 0 float to fp4 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f4_convert_rne
(
0.0
f
)),
abs_tol
);
// convert maximal f4_t to float and check if equal to 6.0
ASSERT_NEAR
(
max_fp4
,
type_convert
<
float
>
(
f4_convert_rne
(
max_fp4
)),
abs_tol
);
// convert maximal float to fp4 and back, check if clipped to 6.0
ASSERT_NEAR
(
max_fp4
,
type_convert
<
float
>
(
f4_convert_rne
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// positive norm float value to fp4 and back, check if holds
float
pos_float
=
1.0
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f4_convert_rne
(
pos_float
)),
abs_tol
);
// negative norm float value to fp4 and back, check if holds
float
neg_float
=
-
1.5
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f4_convert_rne
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp4 and back, check if holds
pos_float
=
0.5
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f4_convert_rne
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp4 and back, check if holds
neg_float
=
-
0.5
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f4_convert_rne
(
neg_float
)),
abs_tol
);
}
TEST
(
FP4
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// set maximum fp4 value
float
max_fp4
=
6.0
f
;
// convert 0 float to fp4 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f4_convert_sr
(
0.0
f
)),
abs_tol
);
// convert maximal f4_t to float and check if equal to 6.0
ASSERT_NEAR
(
max_fp4
,
type_convert
<
float
>
(
f4_convert_sr
(
max_fp4
)),
abs_tol
);
// convert maximal float to fp4 and back, check if clipped to 6.0
ASSERT_NEAR
(
max_fp4
,
type_convert
<
float
>
(
f4_convert_sr
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// positive norm float value to fp4 and back, check if holds
float
pos_float
=
1.0
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f4_convert_sr
(
pos_float
)),
abs_tol
);
// negative norm float value to fp4 and back, check if holds
float
neg_float
=
-
1.5
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f4_convert_sr
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp4 and back, check if holds
pos_float
=
0.5
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f4_convert_sr
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp4 and back, check if holds
neg_float
=
-
0.5
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f4_convert_sr
(
neg_float
)),
abs_tol
);
}
TEST
(
FP4
,
ScaledConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// set maximum fp4 value
float
max_fp4
=
6.0
f
;
// set maximum scale
float
max_scale
=
std
::
pow
(
2
,
ck
::
NumericLimits
<
e8m0_bexp_t
>::
Max
().
data
-
ck
::
NumericUtils
<
e8m0_bexp_t
>::
bias
);
// 0xFE -> float
// set minimum scale
float
min_scale
=
std
::
pow
(
2
,
-
ck
::
NumericUtils
<
e8m0_bexp_t
>::
bias
);
// 0x00 -> float
// set arbitrary scale to 256.0
float
test_scale
=
256.0
f
;
// 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
ASSERT_NEAR
(
0.0
f
,
scaled_type_convert
<
float
>
(
cast_from_float
(
max_scale
),
f4_convert_rne
(
0.0
f
)),
abs_tol
);
// convert 0 float to fp4 and back with minimal scale, check if holds
ASSERT_NEAR
(
0.0
f
,
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_rne
(
0.0
f
)),
abs_tol
);
// convert maximal f4_t with minimal scale to float and check if equal to minimal float
ASSERT_NEAR
(
ck
::
NumericLimits
<
float
>::
Min
(),
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_rne
(
max_fp4
)),
abs_tol
);
// positive norm float value to fp4 and back with various scales, check if holds
float
pos_float
=
1.0
f
;
ASSERT_NEAR
(
pos_float
*
test_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
test_scale
),
f4_convert_rne
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
*
max_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
max_scale
),
f4_convert_rne
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
*
min_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_rne
(
pos_float
)),
abs_tol
);
// negative norm float value to fp4 and back with various scales, check if holds
float
neg_float
=
-
1.5
f
;
ASSERT_NEAR
(
neg_float
*
test_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
test_scale
),
f4_convert_rne
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
*
max_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
max_scale
),
f4_convert_rne
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
*
min_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_rne
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp4 and back with various scales, check if holds
pos_float
=
0.5
f
;
ASSERT_NEAR
(
pos_float
*
test_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
test_scale
),
f4_convert_rne
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
*
max_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
max_scale
),
f4_convert_rne
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
*
min_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_rne
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp4 and back with various scales, check if holds
neg_float
=
-
0.5
f
;
ASSERT_NEAR
(
neg_float
*
test_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
test_scale
),
f4_convert_rne
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
*
max_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
max_scale
),
f4_convert_rne
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
*
min_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_rne
(
neg_float
)),
abs_tol
);
}
TEST
(
FP4
,
ScaledConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// set maximum fp4 value
float
max_fp4
=
6.0
f
;
// set maximum scale
float
max_scale
=
std
::
pow
(
2
,
ck
::
NumericLimits
<
e8m0_bexp_t
>::
Max
().
data
-
ck
::
NumericUtils
<
e8m0_bexp_t
>::
bias
);
// 0xFE -> float
// set minimum scale
float
min_scale
=
std
::
pow
(
2
,
-
ck
::
NumericUtils
<
e8m0_bexp_t
>::
bias
);
// 0x00 -> float
// set arbitrary scale to 256.0
float
test_scale
=
256.0
f
;
// 0b10000111
// convert 0 float to fp4 and back with maximal scale, check if holds
ASSERT_NEAR
(
0.0
f
,
scaled_type_convert
<
float
>
(
cast_from_float
(
max_scale
),
f4_convert_sr
(
0.0
f
)),
abs_tol
);
// convert 0 float to fp4 and back with minimal scale, check if holds
ASSERT_NEAR
(
0.0
f
,
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_sr
(
0.0
f
)),
abs_tol
);
// convert maximal f4_t with minimal scale to float and check if equal to minimal float
ASSERT_NEAR
(
ck
::
NumericLimits
<
float
>::
Min
(),
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_sr
(
max_fp4
)),
abs_tol
);
// positive norm float value to fp4 and back with various scales, check if holds
float
pos_float
=
1.0
f
;
ASSERT_NEAR
(
pos_float
*
test_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
test_scale
),
f4_convert_sr
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
*
max_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
max_scale
),
f4_convert_sr
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
*
min_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_sr
(
pos_float
)),
abs_tol
);
// negative norm float value to fp4 and back with various scales, check if holds
float
neg_float
=
-
1.5
f
;
ASSERT_NEAR
(
neg_float
*
test_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
test_scale
),
f4_convert_sr
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
*
max_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
max_scale
),
f4_convert_sr
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
*
min_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_sr
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp4 and back with various scales, check if holds
pos_float
=
0.5
f
;
ASSERT_NEAR
(
pos_float
*
test_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
test_scale
),
f4_convert_sr
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
*
max_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
max_scale
),
f4_convert_sr
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
*
min_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_sr
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp4 and back with various scales, check if holds
neg_float
=
-
0.5
f
;
ASSERT_NEAR
(
neg_float
*
test_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
test_scale
),
f4_convert_sr
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
*
max_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
max_scale
),
f4_convert_sr
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
*
min_scale
,
scaled_type_convert
<
float
>
(
cast_from_float
(
min_scale
),
f4_convert_sr
(
neg_float
)),
abs_tol
);
}
TEST
(
FP4
,
TestSize
)
{
ASSERT_EQ
(
1
,
sizeof
(
f4x2_pk_t
));
ASSERT_EQ
(
1
,
sizeof
(
vector_type
<
f4x2_pk_t
,
1
>
));
ASSERT_EQ
(
2
,
sizeof
(
vector_type
<
f4x2_pk_t
,
2
>
));
ASSERT_EQ
(
4
,
sizeof
(
vector_type
<
f4x2_pk_t
,
4
>
));
ASSERT_EQ
(
8
,
sizeof
(
vector_type
<
f4x2_pk_t
,
8
>
));
ASSERT_EQ
(
16
,
sizeof
(
vector_type
<
f4x2_pk_t
,
16
>
));
ASSERT_EQ
(
32
,
sizeof
(
vector_type
<
f4x2_pk_t
,
32
>
));
}
TEST
(
FP4
,
TestAlignment
)
{
ASSERT_EQ
(
1
,
alignof
(
f4x2_pk_t
));
ASSERT_EQ
(
1
,
alignof
(
vector_type
<
f4x2_pk_t
,
1
>
));
ASSERT_EQ
(
2
,
alignof
(
vector_type
<
f4x2_pk_t
,
2
>
));
ASSERT_EQ
(
4
,
alignof
(
vector_type
<
f4x2_pk_t
,
4
>
));
ASSERT_EQ
(
8
,
alignof
(
vector_type
<
f4x2_pk_t
,
8
>
));
ASSERT_EQ
(
16
,
alignof
(
vector_type
<
f4x2_pk_t
,
16
>
));
ASSERT_EQ
(
32
,
alignof
(
vector_type
<
f4x2_pk_t
,
32
>
));
}
// test vector of 1 f4x2_pk_t, contains 2 f4_t
TEST
(
FP4
,
TestAsType1
)
{
// test size
const
int
size
=
1
;
std
::
vector
<
f4x2_pk_t
::
type
>
test_vec
=
{
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
}};
// reference vector
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{})
=
f4x2_pk_t
{}.
pack
(
test_vec
.
at
(
i
),
test_vec
.
at
(
i
+
1
));
});
// copy the vector
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
test_vec
.
at
(
i
+
1
));
});
}
// test vector of 2 f4x2_pk_t, contains 4 f4_t
TEST
(
FP4
,
TestAsType2
)
{
// test size
const
int
size
=
2
;
std
::
vector
<
f4x2_pk_t
::
type
>
test_vec
=
{
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
}};
// reference vector
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{})
=
f4x2_pk_t
{}.
pack
(
test_vec
.
at
(
i
),
test_vec
.
at
(
i
+
1
));
});
// copy the vector
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
test_vec
.
at
(
i
+
1
));
});
}
// test vector of 4 f4x2_pk_t, contains 8 f4_t
TEST
(
FP4
,
TestAsType4
)
{
// test size
const
int
size
=
4
;
std
::
vector
<
f4x2_pk_t
::
type
>
test_vec
=
{
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
}};
// reference vector
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{})
=
f4x2_pk_t
{}.
pack
(
test_vec
.
at
(
i
),
test_vec
.
at
(
i
+
1
));
});
// copy the vector
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
test_vec
.
at
(
i
+
1
));
});
}
// test vector of 8 f4x2_pk_t, contains 16 f4_t
TEST
(
FP4
,
TestAsType8
)
{
// test size
const
int
size
=
8
;
std
::
vector
<
f4x2_pk_t
::
type
>
test_vec
=
{
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
}};
// reference vector
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{})
=
f4x2_pk_t
{}.
pack
(
test_vec
.
at
(
i
),
test_vec
.
at
(
i
+
1
));
});
// copy the vector
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
test_vec
.
at
(
i
+
1
));
});
}
// test vector of 16 f4x2_pk_t, contains 32 f4_t
TEST
(
FP4
,
TestAsType16
)
{
// test size
const
int
size
=
16
;
std
::
vector
<
f4x2_pk_t
::
type
>
test_vec
=
{
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
}};
// reference vector
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{})
=
f4x2_pk_t
{}.
pack
(
test_vec
.
at
(
i
),
test_vec
.
at
(
i
+
1
));
});
// copy the vector
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
test_vec
.
at
(
i
+
1
));
});
}
// test vector of 32 f4x2_pk_t, contains 64 f4_t
TEST
(
FP4
,
TestAsType32
)
{
// test size
const
int
size
=
32
;
std
::
vector
<
f4x2_pk_t
::
type
>
test_vec
=
{
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0111
},
f4x2_pk_t
::
type
{
0b1010
},
f4x2_pk_t
::
type
{
0b0001
},
f4x2_pk_t
::
type
{
0b0010
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1001
},
f4x2_pk_t
::
type
{
0b1111
}};
// reference vector
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{})
=
f4x2_pk_t
{}.
pack
(
test_vec
.
at
(
i
),
test_vec
.
at
(
i
+
1
));
});
// copy the vector
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
test_vec
.
at
(
i
+
1
));
});
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment