Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1504c3e8
Unverified
Commit
1504c3e8
authored
Nov 22, 2024
by
Illia Silin
Committed by
GitHub
Nov 22, 2024
Browse files
Merge pull request #219 from ROCm/andriy/lwpck-2430
Add support of OCP FP8 data types in CK for gfx950 arch
parents
b6f7cddd
27a05c7e
Changes
61
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1037 additions
and
306 deletions
+1037
-306
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+247
-77
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+2
-2
include/ck/utility/random_gen.hpp
include/ck/utility/random_gen.hpp
+8
-5
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+128
-60
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+5
-5
library/include/ck/library/utility/host_tensor.hpp
library/include/ck/library/utility/host_tensor.hpp
+1
-1
library/include/ck/library/utility/host_tensor_generator.hpp
library/include/ck/library/utility/host_tensor_generator.hpp
+25
-6
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+2
-2
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f8_instance.cpp
...pu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f8_instance.cpp
+2
-2
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
...r/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
+2
-2
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
+2
-2
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
...clude/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
+2
-2
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
...ofiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
+2
-2
profiler/include/profiler/profile_gemm_impl.hpp
profiler/include/profiler/profile_gemm_impl.hpp
+3
-3
test/CMakeLists.txt
test/CMakeLists.txt
+1
-1
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+31
-6
test/data_type/test_bf8_fnuz.cpp
test/data_type/test_bf8_fnuz.cpp
+73
-62
test/data_type/test_bf8_ocp.cpp
test/data_type/test_bf8_ocp.cpp
+268
-0
test/data_type/test_custom_type.cpp
test/data_type/test_custom_type.cpp
+150
-0
test/data_type/test_fp8_fnuz.cpp
test/data_type/test_fp8_fnuz.cpp
+83
-66
No files found.
include/ck/utility/data_type.hpp
View file @
1504c3e8
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#pragma once
#pragma once
#include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/statically_indexed_array.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -10,8 +11,6 @@ namespace ck {
...
@@ -10,8 +11,6 @@ namespace ck {
using
bhalf_t
=
ushort
;
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
{
...
@@ -19,14 +18,15 @@ inline constexpr auto next_pow2(uint32_t x)
...
@@ -19,14 +18,15 @@ inline constexpr auto next_pow2(uint32_t x)
return
x
>
1u
?
(
1u
<<
(
32u
-
__builtin_clz
(
x
-
1u
)))
:
x
;
return
x
>
1u
?
(
1u
<<
(
32u
-
__builtin_clz
(
x
-
1u
)))
:
x
;
}
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
template
<
typename
T
>
template
<
typename
T
>
inline
constexpr
bool
is_native_type
()
inline
constexpr
bool
is_native_type
()
{
{
return
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
half_t
>::
value
||
return
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
bhalf_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
bhalf_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
uint8_t
>::
value
||
is_same
<
T
,
f8_
t
>::
value
||
is_same
<
T
,
bf8
_t
>::
value
||
is_same
<
T
,
uint8_t
>::
value
||
is_same
<
T
,
f8_
fnuz
_t
>::
value
||
is_same
<
T
,
bool
>::
value
;
is_same
<
T
,
bf8_fnuz_t
>::
value
||
is_same
<
T
,
bool
>::
value
;
}
}
// vector_type
// vector_type
...
@@ -166,16 +166,30 @@ struct scalar_type<int4_t>
...
@@ -166,16 +166,30 @@ struct scalar_type<int4_t>
#endif
#endif
template
<
>
template
<
>
struct
scalar_type
<
f8_t
>
struct
scalar_type
<
f8_
fnuz_
t
>
{
{
using
type
=
f8_t
;
using
type
=
f8_
fnuz_
t
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
template
<
>
template
<
>
struct
scalar_type
<
bf8_t
>
struct
scalar_type
<
bf8_
fnuz_
t
>
{
{
using
type
=
bf8_t
;
using
type
=
bf8_fnuz_t
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
struct
scalar_type
<
f8_ocp_t
>
{
using
type
=
f8_ocp_t
::
data_type
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
struct
scalar_type
<
bf8_ocp_t
>
{
using
type
=
bf8_ocp_t
::
data_type
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
...
@@ -1023,47 +1037,83 @@ struct non_native_vector_base
...
@@ -1023,47 +1037,83 @@ struct non_native_vector_base
T
d
[
N
];
T
d
[
N
];
};
};
template
<
typename
T
,
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
T
,
N
>>
;
template
<
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
f8_ocp_t
,
N
>>
{
using
type
=
typename
non_native_vector_base
<
f8_ocp_t
,
N
>::
data_t
;
static
constexpr
index_t
vector_size
=
N
;
};
template
<
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
bf8_ocp_t
,
N
>>
{
using
type
=
typename
non_native_vector_base
<
bf8_ocp_t
,
N
>::
data_t
;
static
constexpr
index_t
vector_size
=
N
;
};
// non-native vector_type implementation
// non-native vector_type implementation
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
using
type
=
d1_t
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
using
type
=
d1_nnv_t
;
union
alignas
(
next_pow2
(
1
*
sizeof
(
T
)))
union
alignas
(
next_pow2
(
1
*
sizeof
(
T
)))
{
{
d1_t
d1_
;
d1_t
d1_
;
StaticallyIndexedArray
<
d1_t
,
1
>
d1x1_
;
StaticallyIndexedArray
<
d1_t
,
1
>
d1x1_
;
d1_nnv_t
d1_nnv_
;
}
data_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
()
:
data_
{
d1_t
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
{
v
}
}
{}
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
,
"Something went wrong, please check src and dst types."
);
"Something went wrong, please check src and dst types."
);
return
data_
.
d1x1_
;
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x1_
;
}
else
{
return
err
;
}
}
}
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
__host__
__device__
constexpr
auto
&
AsType
()
{
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
,
"Something went wrong, please check src and dst types."
);
"Something went wrong, please check src and dst types."
);
return
data_
.
d1x1_
;
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x1_
;
}
else
{
return
err
;
}
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
type
=
d2_t
;
using
type
=
d2_t
;
...
@@ -1081,10 +1131,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1081,10 +1131,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
{
return
data_
.
d1x2_
;
return
data_
.
d1x2_
;
}
}
...
@@ -1101,10 +1152,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1101,10 +1152,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
__host__
__device__
constexpr
auto
&
AsType
()
{
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
{
return
data_
.
d1x2_
;
return
data_
.
d1x2_
;
}
}
...
@@ -1122,9 +1174,10 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1122,9 +1174,10 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
type
=
d4_t
;
using
type
=
d4_t
;
...
@@ -1143,10 +1196,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1143,10 +1196,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
{
return
data_
.
d1x4_
;
return
data_
.
d1x4_
;
}
}
...
@@ -1167,10 +1221,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1167,10 +1221,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
__host__
__device__
constexpr
auto
&
AsType
()
{
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
{
return
data_
.
d1x4_
;
return
data_
.
d1x4_
;
}
}
...
@@ -1192,10 +1247,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1192,10 +1247,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
type
=
d8_t
;
using
type
=
d8_t
;
...
@@ -1215,11 +1271,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1215,11 +1271,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
{
return
data_
.
d1x8_
;
return
data_
.
d1x8_
;
}
}
...
@@ -1244,11 +1301,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1244,11 +1301,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
__host__
__device__
constexpr
auto
&
AsType
()
{
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
{
return
data_
.
d1x8_
;
return
data_
.
d1x8_
;
}
}
...
@@ -1274,11 +1332,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1274,11 +1332,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d16_t
=
non_native_vector_base
<
T
,
16
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
d16_t
=
non_native_vector_base
<
T
,
16
>
;
using
type
=
d16_t
;
using
type
=
d16_t
;
...
@@ -1299,12 +1358,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1299,12 +1358,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d
2
_t
>::
value
||
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d
1_nnv
_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
||
is_same
<
X
,
d
8
_t
>::
value
||
is_same
<
X
,
d
2
_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
{
return
data_
.
d1x16_
;
return
data_
.
d1x16_
;
}
}
...
@@ -1333,12 +1392,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1333,12 +1392,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
__host__
__device__
constexpr
auto
&
AsType
()
{
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d
2
_t
>::
value
||
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d
1_nnv
_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
||
is_same
<
X
,
d
8
_t
>::
value
||
is_same
<
X
,
d
2
_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
{
return
data_
.
d1x16_
;
return
data_
.
d1x16_
;
}
}
...
@@ -1632,20 +1691,70 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
...
@@ -1632,20 +1691,70 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
// f8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x2_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x4_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x8_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x16_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x32_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
using
f8x64_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
64
>::
type
;
// bf8
using
bf8x2_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
2
>::
type
;
using
bf8x4_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
4
>::
type
;
using
bf8x8_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
8
>::
type
;
using
bf8x16_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
16
>::
type
;
using
bf8x32_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
32
>::
type
;
using
bf8x64_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
64
>::
type
;
// f8
using
f8x2_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
2
>::
type
;
using
f8x4_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
4
>::
type
;
using
f8x8_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
8
>::
type
;
using
f8x16_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
16
>::
type
;
using
f8x32_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
32
>::
type
;
using
f8x64_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
64
>::
type
;
// bf8
using
bf8x2_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
2
>::
type
;
using
bf8x4_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
4
>::
type
;
using
bf8x8_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
8
>::
type
;
using
bf8x16_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
16
>::
type
;
using
bf8x32_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
32
>::
type
;
using
bf8x64_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
64
>::
type
;
#if CK_FP8_TYPE_OCP
// f8
using
f8x2_t
=
f8x2_ocp_t
;
using
f8x4_t
=
f8x4_ocp_t
;
using
f8x8_t
=
f8x8_ocp_t
;
using
f8x16_t
=
f8x16_ocp_t
;
using
f8x32_t
=
f8x32_ocp_t
;
using
f8x64_t
=
f8x64_ocp_t
;
// bf8
// bf8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x2_t
=
bf8x2_ocp_t
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x4_t
=
bf8x4_ocp_t
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x8_t
=
bf8x8_ocp_t
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x16_t
=
bf8x16_ocp_t
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x32_t
=
bf8x32_ocp_t
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
using
bf8x64_t
=
bf8x64_ocp_t
;
#elif CK_FP8_TYPE_FNUZ
// f8
using
f8x2_t
=
f8x2_fnuz_t
;
using
f8x4_t
=
f8x4_fnuz_t
;
using
f8x8_t
=
f8x8_fnuz_t
;
using
f8x16_t
=
f8x16_fnuz_t
;
using
f8x32_t
=
f8x32_fnuz_t
;
using
f8x64_t
=
f8x64_fnuz_t
;
// bf8
using
bf8x2_t
=
bf8x2_fnuz_t
;
using
bf8x4_t
=
bf8x4_fnuz_t
;
using
bf8x8_t
=
bf8x8_fnuz_t
;
using
bf8x16_t
=
bf8x16_fnuz_t
;
using
bf8x32_t
=
bf8x32_fnuz_t
;
using
bf8x64_t
=
bf8x64_fnuz_t
;
#endif
// u8
// u8
using
uint8x2_t
=
typename
vector_type
<
uint8_t
,
2
>::
type
;
using
uint8x2_t
=
typename
vector_type
<
uint8_t
,
2
>::
type
;
...
@@ -1702,7 +1811,7 @@ struct NumericLimits<int4_t>
...
@@ -1702,7 +1811,7 @@ struct NumericLimits<int4_t>
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
template
<
>
struct
NumericLimits
<
f8_t
>
struct
NumericLimits
<
f8_
fnuz_
t
>
{
{
// negative zero nan mode with exp bias = 8
// negative zero nan mode with exp bias = 8
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
...
@@ -1715,17 +1824,17 @@ struct NumericLimits<f8_t>
...
@@ -1715,17 +1824,17 @@ struct NumericLimits<f8_t>
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
f8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
Min
()
{
return
f8_
fnuz_
t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
f8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
Max
()
{
return
f8_
fnuz_
t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
f8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
Lowest
()
{
return
f8_
fnuz_
t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
QuietNaN
()
{
return
f8_
fnuz_
t
(
binary_qnan
);
}
};
};
template
<
>
template
<
>
struct
NumericLimits
<
bf8_t
>
struct
NumericLimits
<
bf8_
fnuz_
t
>
{
{
// negative zero nan mode with exp bias = 16
// negative zero nan mode with exp bias = 16
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100
...
@@ -1738,13 +1847,59 @@ struct NumericLimits<bf8_t>
...
@@ -1738,13 +1847,59 @@ struct NumericLimits<bf8_t>
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__
__device__
static
constexpr
bf8_t
Min
()
{
return
bf8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
bf8_
fnuz_
t
Min
()
{
return
bf8_
fnuz_
t
(
binary_min
);
}
__host__
__device__
static
constexpr
bf8_t
Max
()
{
return
bf8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
bf8_
fnuz_
t
Max
()
{
return
bf8_
fnuz_
t
(
binary_max
);
}
__host__
__device__
static
constexpr
bf8_t
Lowest
()
{
return
bf8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
bf8_
fnuz_
t
Lowest
()
{
return
bf8_
fnuz_
t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
bf8_fnuz_t
QuietNaN
()
{
return
bf8_fnuz_t
(
binary_qnan
);
}
};
template
<
>
struct
NumericLimits
<
f8_ocp_t
>
{
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000 = 2^-6
static
constexpr
uint8_t
binary_max
=
0x7E
;
// 0b01111110 = 448
static
constexpr
uint8_t
binary_lowest
=
0xFE
;
// 0b11111110 = -448
static
constexpr
uint8_t
binary_qnan
=
0x7F
;
// 0b01111111
__host__
__device__
static
constexpr
f8_ocp_t
Min
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_ocp_t
Max
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_ocp_t
Lowest
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_ocp_t
QuietNaN
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_qnan
);
}
};
template
<
>
struct
NumericLimits
<
bf8_ocp_t
>
{
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100 = 2^-14
static
constexpr
uint8_t
binary_max
=
0x7B
;
// 0b01111011 = 57344
static
constexpr
uint8_t
binary_lowest
=
0xFB
;
// 0b11111011 = -57344
static
constexpr
uint8_t
binary_qnan
=
0x7D
;
// 0b01111101
__host__
__device__
static
constexpr
bf8_ocp_t
Min
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
bf8_ocp_t
Max
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
bf8_ocp_t
Lowest
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
bf8_ocp_t
QuietNaN
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_qnan
);
}
};
};
template
<
typename
T
>
template
<
typename
T
>
...
@@ -1787,7 +1942,7 @@ struct NumericUtils<half_t>
...
@@ -1787,7 +1942,7 @@ struct NumericUtils<half_t>
};
};
template
<
>
template
<
>
struct
NumericUtils
<
f8_t
>
struct
NumericUtils
<
f8_
fnuz_
t
>
{
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
mant
=
3
;
...
@@ -1796,13 +1951,28 @@ struct NumericUtils<f8_t>
...
@@ -1796,13 +1951,28 @@ struct NumericUtils<f8_t>
};
};
template
<
>
template
<
>
struct
NumericUtils
<
bf8_t
>
struct
NumericUtils
<
bf8_
fnuz_
t
>
{
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
16
;
// negative zero nan mode
static
constexpr
int
bias
=
16
;
// negative zero nan mode
// static constexpr int bias = 15; // ieee mode
// static constexpr int bias = 15; // ieee mode
};
};
template
<
>
struct
NumericUtils
<
f8_ocp_t
>
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
7
;
};
template
<
>
struct
NumericUtils
<
bf8_ocp_t
>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
15
;
};
template
<
>
template
<
>
struct
NumericUtils
<
bhalf_t
>
struct
NumericUtils
<
bhalf_t
>
...
...
include/ck/utility/math_v2.hpp
View file @
1504c3e8
...
@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
...
@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
return
(
xx
&
0x7FFF
)
>
0x7C00
;
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
};
static
inline
__host__
bool
isnan
(
f8_t
x
)
{
return
(
x
&
0x80
);
};
static
inline
__host__
bool
isnan
(
f8_t
x
)
{
return
ck
::
fp8_is_nan
(
x
);
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__host__
bool
isnan
(
int4_t
x
)
static
inline
__host__
bool
isnan
(
int4_t
x
)
...
@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
...
@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
return
(
xx
&
0x7FFF
)
>
0x7C00
;
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
};
static
inline
__device__
bool
isnan
(
f8_t
x
)
{
return
(
x
&
0x80
);
};
static
inline
__device__
bool
isnan
(
f8_t
x
)
{
return
ck
::
fp8_is_nan
(
x
);
};
static
inline
__device__
half_t
sqrt
(
half_t
x
)
static
inline
__device__
half_t
sqrt
(
half_t
x
)
{
{
...
...
include/ck/utility/random_gen.hpp
View file @
1504c3e8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/ck.hpp"
namespace
ck
{
namespace
ck
{
// Pseudo random number generator
// Pseudo random number generator
...
@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
...
@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
}
// version for fp16
// version for fp16
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<
std
::
is_same
<
half_t
,
T
>{},
bool
>
=
false
>
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<
std
::
is_same
<
_Float16
,
T
>{},
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
{
uint16_t
x
=
*
(
reinterpret_cast
<
uint16_t
*>
(
&
val
));
uint16_t
x
=
*
(
reinterpret_cast
<
uint16_t
*>
(
&
val
));
...
@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
...
@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
}
// return 0 if data is not fp16 or fp32
// return 0 if data is not fp16 or fp32
template
<
typename
T
,
template
<
uint32_t
seed_t
,
typename
T
,
std
::
enable_if_t
<!
(
std
::
is_same
<
float
,
T
>{}
||
std
::
is_same
<
half_t
,
T
>
{}),
bool
>
=
false
>
uint32_t
seed_t
,
std
::
enable_if_t
<!
(
std
::
is_same
<
float
,
T
>{}
||
std
::
is_same
<
_Float16
,
T
>
{}),
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
int
id
,
T
val
,
uint32_t
seed
=
seed_t
)
__host__
__device__
uint32_t
prand_generator
(
int
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
{
std
::
ignore
=
id
;
std
::
ignore
=
id
;
...
...
include/ck/utility/type_convert.hpp
View file @
1504c3e8
...
@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
...
@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
}
template
<
>
inline
__host__
__device__
constexpr
f8_ocp_t
type_convert
<
f8_ocp_t
,
int
>
(
int
x
)
{
return
f8_ocp_t
{
type_convert
<
f8_ocp_t
::
data_type
>
(
x
)};
}
template
<
>
inline
__host__
__device__
constexpr
bf8_ocp_t
type_convert
<
bf8_ocp_t
,
int
>
(
int
x
)
{
return
bf8_ocp_t
{
type_convert
<
bf8_ocp_t
::
data_type
>
(
x
)};
}
// Convert X to Y
// Convert X to Y
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert_sp
(
X
x
)
__host__
__device__
constexpr
Y
type_convert_sp
(
X
x
)
...
@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
...
@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
// convert fp32 to fp8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_sr
<
f8_
fnuz_
t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
...
@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
...
@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
return
utils
::
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
cast_to_f8
<
float
,
f8_
fnuz_
t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
rng
);
x
,
rng
);
#endif
#endif
}
}
// convert fp16 to fp8 with stochastic rounding
// convert fp16 to fp8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_sr
<
f8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_sr
<
f8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
f8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
#endif
}
}
// convert fp32 to bf8 with stochastic rounding
// convert fp32 to bf8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_sr
<
bf8_
fnuz_
t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
...
@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
...
@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
return
utils
::
return
utils
::
cast_to_f8
<
float
,
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
bf8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
#endif
}
}
// convert fp16 to bf8 with stochastic rounding
// convert fp16 to bf8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_sr
<
bf8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_sr
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_sr
<
bf8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
bf8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
#endif
}
}
...
@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
...
@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
// convert fp32 to fp8 with rounding to nearest even
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_rne
<
f8_
fnuz_
t
,
float
>
(
float
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
union
union
...
@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
...
@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
cast_to_f8
<
float
,
f8_
fnuz_
t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
rng
);
x
,
rng
);
#endif
#endif
}
}
// convert fp16 to fp8 with rounding to nearest even
// convert fp16 to fp8 with rounding to nearest even
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_rne
<
f8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_rne
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_rne
<
f8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
f8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
#endif
}
}
// convert fp32 to bf8 with rounding to nearest even
// convert fp32 to bf8 with rounding to nearest even
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_rne
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_rne
<
bf8_
fnuz_
t
,
float
>
(
float
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
union
union
...
@@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
...
@@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
return
utils
::
cast_to_f8
<
float
,
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
bf8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
#endif
}
}
// convert fp16 to bf8 with rounding to nearest even
// convert fp16 to bf8 with rounding to nearest even
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_rne
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_rne
<
bf8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_rne
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_rne
<
bf8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
bf8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp32 to fp8
template
<
>
inline
__host__
__device__
f8_fnuz_t
type_convert
<
f8_fnuz_t
,
float
>
(
float
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_fnuz_t
>
(
x
);
#else
return
f8_convert_rne
<
f8_fnuz_t
>
(
x
);
#endif
#endif
}
}
// convert fp32 to fp8
// convert fp32 to fp8
template
<
>
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_
ocp_
t
type_convert
<
f8_
ocp_
t
,
float
>
(
float
x
)
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_t
>
(
x
);
return
f8_convert_sr
<
f8_
ocp_
t
>
(
x
);
#else
#else
return
f8_convert_rne
<
f8_t
>
(
x
);
return
f8_convert_rne
<
f8_
ocp_
t
>
(
x
);
#endif
#endif
}
}
// convert fp8 to fp32
// convert fp8 to fp32
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
f8_
t
>
(
f8
_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
f8_
fnuz_t
>
(
f8_fnuz
_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
float
fval
;
float
fval
;
...
@@ -392,26 +427,26 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
...
@@ -392,26 +427,26 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return
fval
;
return
fval
;
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
x
);
#endif
#endif
}
}
template
<
>
template
<
>
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f8x2_t
>
(
f8x2_t
x
)
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f8x2_
fnuz_
t
>
(
f8x2_
fnuz_
t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
const
auto
i16val
=
bit_cast
<
uint16_t
>
(
x
);
const
auto
i16val
=
bit_cast
<
uint16_t
>
(
x
);
return
__builtin_amdgcn_cvt_pk_f32_fp8
(
i16val
,
0
);
return
__builtin_amdgcn_cvt_pk_f32_fp8
(
i16val
,
0
);
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
const
auto
f8x2_v
=
vector_type
<
f8_t
,
2
>
(
x
);
const
auto
f8x2_v
=
vector_type
<
f8_
fnuz_
t
,
2
>
(
x
);
vector_type
<
float
,
2
>
f32x2_v
;
vector_type
<
float
,
2
>
f32x2_v
;
f32x2_v
.
template
AsType
<
float
>()(
Number
<
0
>
{})
=
f32x2_v
.
template
AsType
<
float
>()(
Number
<
0
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_t
>()[
Number
<
0
>
{}]);
f8x2_v
.
template
AsType
<
f8_
fnuz_
t
>()[
Number
<
0
>
{}]);
f32x2_v
.
template
AsType
<
float
>()(
Number
<
1
>
{})
=
f32x2_v
.
template
AsType
<
float
>()(
Number
<
1
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_t
>()[
Number
<
1
>
{}]);
f8x2_v
.
template
AsType
<
f8_
fnuz_
t
>()[
Number
<
1
>
{}]);
return
f32x2_v
.
template
AsType
<
float2_t
>()[
Number
<
0
>
{}];
return
f32x2_v
.
template
AsType
<
float2_t
>()[
Number
<
0
>
{}];
#endif
#endif
}
}
...
@@ -428,42 +463,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
...
@@ -428,42 +463,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// convert fp16 to fp8
// convert fp16 to fp8
template
<
>
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_fnuz_t
type_convert
<
f8_fnuz_t
,
half_t
>
(
half_t
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_fnuz_t
>
(
x
);
#else
return
f8_convert_rne
<
f8_fnuz_t
>
(
x
);
#endif
}
// convert fp16 to fp8
template
<
>
inline
__host__
__device__
f8_ocp_t
type_convert
<
f8_ocp_t
,
half_t
>
(
half_t
x
)
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_t
>
(
x
);
return
f8_convert_sr
<
f8_
ocp_
t
>
(
x
);
#else
#else
return
f8_convert_rne
<
f8_t
>
(
x
);
return
f8_convert_rne
<
f8_
ocp_
t
>
(
x
);
#endif
#endif
}
}
// convert fp8 to fp16
// convert fp8 to fp16
template
<
>
template
<
>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_
t
>
(
f8
_t
x
)
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_
fnuz_t
>
(
f8_fnuz
_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
half_t
,
negative_zero_nan
>
(
x
);
#endif
#endif
}
}
// convert fp32 to bf8
// convert fp32 to bf8
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_
fnuz_
t
type_convert
<
bf8_
fnuz_
t
,
float
>
(
float
x
)
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_t
>
(
x
);
return
f8_convert_sr
<
bf8_
fnuz_
t
>
(
x
);
#else
#else
return
f8_convert_rne
<
bf8_t
>
(
x
);
return
f8_convert_rne
<
bf8_fnuz_t
>
(
x
);
#endif
}
// convert fp32 to bf8
template
<
>
inline
__host__
__device__
bf8_ocp_t
type_convert
<
bf8_ocp_t
,
float
>
(
float
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_ocp_t
>
(
x
);
#else
return
f8_convert_rne
<
bf8_ocp_t
>
(
x
);
#endif
#endif
}
}
// convert bf8 to fp32
// convert bf8 to fp32
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
bf8_t
>
(
bf8_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
bf8_
fnuz_
t
>
(
bf8_
fnuz_
t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
float
fval
;
float
fval
;
...
@@ -473,31 +530,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
...
@@ -473,31 +530,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
return
fval
;
return
fval
;
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_fnuz_t
,
float
,
negative_zero_nan
>
(
x
);
#endif
}
// convert fp16 to bf8
template
<
>
inline
__host__
__device__
bf8_fnuz_t
type_convert
<
bf8_fnuz_t
,
half_t
>
(
half_t
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_fnuz_t
>
(
x
);
#else
return
f8_convert_rne
<
bf8_fnuz_t
>
(
x
);
#endif
#endif
}
}
// convert fp16 to bf8
// convert fp16 to bf8
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_
ocp_
t
type_convert
<
bf8_
ocp_
t
,
half_t
>
(
half_t
x
)
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_t
>
(
x
);
return
f8_convert_sr
<
bf8_
ocp_
t
>
(
x
);
#else
#else
return
f8_convert_rne
<
bf8_t
>
(
x
);
return
f8_convert_rne
<
bf8_
ocp_
t
>
(
x
);
#endif
#endif
}
}
// convert bf8 to fp16
// convert bf8 to fp16
template
<
>
template
<
>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
bf8_t
>
(
bf8_t
x
)
inline
__host__
__device__
half_t
type_convert
<
half_t
,
bf8_
fnuz_
t
>
(
bf8_
fnuz_
t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_
fnuz_
t
,
half_t
,
negative_zero_nan
>
(
x
);
#endif
#endif
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
1504c3e8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -62,9 +62,9 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -62,9 +62,9 @@ struct ReferenceGemm : public device::BaseOperator
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
AccDataType
v_acc
=
0
;
AccDataType
v_acc
{
0
}
;
ComputeTypeA
v_a
=
0
;
ComputeTypeA
v_a
{
0
}
;
ComputeTypeB
v_b
=
0
;
ComputeTypeB
v_b
{
0
}
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
...
@@ -93,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -93,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
CDataType
v_c
=
0
;
CDataType
v_c
{
0
}
;
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_element_op_
(
v_c
,
v_acc
);
...
...
library/include/ck/library/utility/host_tensor.hpp
View file @
1504c3e8
...
@@ -326,7 +326,7 @@ struct Tensor
...
@@ -326,7 +326,7 @@ struct Tensor
std
::
size_t
GetElementSpaceSizeInBytes
()
const
{
return
sizeof
(
T
)
*
GetElementSpaceSize
();
}
std
::
size_t
GetElementSpaceSizeInBytes
()
const
{
return
sizeof
(
T
)
*
GetElementSpaceSize
();
}
void
SetZero
()
{
ck
::
ranges
::
fill
<
T
>
(
mData
,
0
);
}
void
SetZero
()
{
ck
::
ranges
::
fill
<
T
>
(
mData
,
T
{
0
}
);
}
template
<
typename
F
>
template
<
typename
F
>
void
ForEach_impl
(
F
&&
f
,
std
::
vector
<
size_t
>&
idx
,
size_t
rank
)
void
ForEach_impl
(
F
&&
f
,
std
::
vector
<
size_t
>&
idx
,
size_t
rank
)
...
...
library/include/ck/library/utility/host_tensor_generator.hpp
View file @
1504c3e8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
...
@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
float
value
=
1.0
;
float
value
=
1.0
;
template
<
typename
...
Is
>
template
<
typename
...
Is
>
ck
::
b
half_t
operator
()(
Is
...)
ck
::
half_t
operator
()(
Is
...)
{
{
return
ck
::
type_convert
<
ck
::
half_t
>
(
value
);
return
ck
::
type_convert
<
ck
::
half_t
>
(
value
);
}
}
...
@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
...
@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
float
value
=
1.0
;
float
value
=
1.0
;
template
<
typename
...
Is
>
template
<
typename
...
Is
>
ck
::
bhal
f_t
operator
()(
Is
...)
ck
::
f
8
_t
operator
()(
Is
...)
{
{
return
ck
::
type_convert
<
ck
::
f8_t
>
(
value
);
return
ck
::
type_convert
<
ck
::
f8_t
>
(
value
);
}
}
...
@@ -256,14 +256,33 @@ struct GeneratorTensor_Checkboard
...
@@ -256,14 +256,33 @@ struct GeneratorTensor_Checkboard
}
}
};
};
template
<
ck
::
index_t
Dim
>
/**
* @brief Is used to generate sequential values based on the specified dimension.
*
* @tparam T The type of the tensor values.
* @tparam Dim The specific dimension used for generation.
*
* GeneratorTensor_Sequential<1>{} will generate the following values for a 3x3 tensor:
*
* 0 1 2
* 0 1 2
* 0 1 2
*
* Essentially, the values generated are logical coordinates of the generated element that
* correspond to dimension Dim. E.g. for 2-dimensional tensor and Dim=1, the values are the column
* indices.
*
*/
template
<
typename
T
,
ck
::
index_t
Dim
>
struct
GeneratorTensor_Sequential
struct
GeneratorTensor_Sequential
{
{
template
<
typename
...
Ts
>
template
<
typename
...
Ts
>
float
operator
()(
Ts
...
Xs
)
const
T
operator
()(
Ts
...
Xs
)
const
{
{
std
::
array
<
ck
::
index_t
,
sizeof
...(
Ts
)
>
dims
=
{{
static_cast
<
ck
::
index_t
>
(
Xs
)...}};
std
::
array
<
ck
::
index_t
,
sizeof
...(
Ts
)
>
dims
=
{{
static_cast
<
ck
::
index_t
>
(
Xs
)...}};
return
dims
[
Dim
];
float
tmp
=
dims
[
Dim
];
return
ck
::
type_convert
<
T
>
(
tmp
);
}
}
};
};
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
1504c3e8
...
@@ -70,13 +70,13 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -70,13 +70,13 @@ function(add_instance_library INSTANCE_NAME)
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_multiply_multiply_xdl_f8"
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND
NOT INST_TARGETS MATCHES
"gfx95"
AND
source MATCHES
"gemm_multiply_multiply_xdl_f8"
)
message
(
"removing gemm_multiply_multiply_f8 instance
${
source
}
"
)
message
(
"removing gemm_multiply_multiply_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"_f8_"
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND
NOT INST_TARGETS MATCHES
"gfx95"
AND
source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"_f8_"
)
message
(
"removing gemm_universal_f8 instance
${
source
}
"
)
message
(
"removing gemm_universal_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
...
...
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f8_instance.cpp
View file @
1504c3e8
...
@@ -15,7 +15,7 @@ void add_device_pool3d_fwd_ndhwc_f8_instances(
...
@@ -15,7 +15,7 @@ void add_device_pool3d_fwd_ndhwc_f8_instances(
instances
)
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
F8
,
F8
,
I32
,
F
8
,
ReduceOpId
,
false
>
{});
instances
,
device_pool3d_fwd_ndhwc_instances
<
F8
,
F8
,
I32
,
F
32
,
ReduceOpId
,
false
>
{});
}
}
void
add_device_pool3d_fwd_ndhwc_index_f8_instances
(
void
add_device_pool3d_fwd_ndhwc_index_f8_instances
(
...
@@ -23,7 +23,7 @@ void add_device_pool3d_fwd_ndhwc_index_f8_instances(
...
@@ -23,7 +23,7 @@ void add_device_pool3d_fwd_ndhwc_index_f8_instances(
instances
)
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
F8
,
F8
,
I32
,
F
8
,
ReduceOpId
,
true
>
{});
instances
,
device_pool3d_fwd_ndhwc_instances
<
F8
,
F8
,
I32
,
F
32
,
ReduceOpId
,
true
>
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
View file @
1504c3e8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -150,7 +150,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
...
@@ -150,7 +150,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
break
;
break
;
default:
default:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
D0DataType
>
{
1
});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
D0DataType
>
{
1
});
}
}
...
...
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
View file @
1504c3e8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
...
@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
break
;
break
;
default:
default:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
}
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
View file @
1504c3e8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
...
@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
break
;
break
;
default:
default:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
}
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
View file @
1504c3e8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
...
@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
break
;
break
;
default:
default:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
}
...
...
profiler/include/profiler/profile_gemm_impl.hpp
View file @
1504c3e8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -74,8 +74,8 @@ int profile_gemm_impl(int do_verification,
...
@@ -74,8 +74,8 @@ int profile_gemm_impl(int do_verification,
switch
(
init_method
)
switch
(
init_method
)
{
{
case
0
:
case
0
:
ck
::
utils
::
FillConstant
<
ADataType
>
{
static_cas
t
<
ADataType
>
(
1.
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
ADataType
>
{
type_conver
t
<
ADataType
>
(
1.
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
static_cas
t
<
BDataType
>
(
1.
f
)}(
b_k_n
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
type_conver
t
<
BDataType
>
(
1.
f
)}(
b_k_n
);
break
;
break
;
case
1
:
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
...
...
test/CMakeLists.txt
View file @
1504c3e8
...
@@ -206,7 +206,7 @@ add_subdirectory(wrapper)
...
@@ -206,7 +206,7 @@ add_subdirectory(wrapper)
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx11"
)
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
add_subdirectory
(
wmma_op
)
endif
()
endif
()
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx942"
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
if
(
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx942"
OR SUPPORTED_GPU_TARGETS MATCHES
"gfx95"
)
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
add_subdirectory
(
smfmac_op
)
add_subdirectory
(
smfmac_op
)
endif
()
endif
()
add_subdirectory
(
position_embedding
)
add_subdirectory
(
position_embedding
)
...
...
test/data_type/CMakeLists.txt
View file @
1504c3e8
...
@@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4)
...
@@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4)
endif
()
endif
()
endif
()
endif
()
add_gtest_executable
(
test_fp8 test_fp8.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8 PRIVATE utility
)
add_custom_target
(
test_fp8
)
if
(
CK_USE_OCP_FP8
)
add_gtest_executable
(
test_fp8_ocp test_fp8_ocp.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8_ocp PRIVATE utility
)
endif
()
add_gtest_executable
(
test_bf8_ocp test_bf8_ocp.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf8_ocp PRIVATE utility
)
endif
()
add_dependencies
(
test_fp8 test_fp8_ocp
)
add_dependencies
(
test_fp8 test_bf8_ocp
)
endif
()
endif
()
add_gtest_executable
(
test_bf8 test_bf8.cpp
)
if
(
result EQUAL 0
)
if
(
CK_USE_FNUZ_FP8
)
target_link_libraries
(
test_bf8 PRIVATE utility
)
add_gtest_executable
(
test_fp8_fnuz test_fp8_fnuz.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8_fnuz PRIVATE utility
)
endif
()
add_gtest_executable
(
test_bf8_fnuz test_bf8_fnuz.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf8_fnuz PRIVATE utility
)
endif
()
add_dependencies
(
test_fp8 test_fp8_fnuz
)
add_dependencies
(
test_fp8 test_bf8_fnuz
)
endif
()
endif
()
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
...
...
test/data_type/test_bf8.cpp
→
test/data_type/test_bf8
_fnuz
.cpp
View file @
1504c3e8
...
@@ -5,158 +5,169 @@
...
@@ -5,158 +5,169 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_t
;
using
ck
::
bf8_
fnuz_
t
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_convert_sr
;
using
ck
::
half_t
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
using
ck
::
type_convert
;
TEST
(
BF8
,
NumericLimits
)
TEST
(
BF8
FNUZ
,
NumericLimits
)
{
{
// constants given for negative zero nan mode
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Min
(),
type_convert
<
bf8_t
>
(
0x04
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Min
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x04
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Max
(),
type_convert
<
bf8_t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Max
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Lowest
(),
type_convert
<
bf8_t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Lowest
(),
type_convert
<
bf8_
fnuz_
t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
QuietNaN
(),
type_convert
<
bf8_t
>
(
0x80
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
QuietNaN
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x80
));
}
}
TEST
(
BF8
,
ConvertFP32Nearest
)
TEST
(
BF8
FNUZ
,
ConvertFP32Nearest
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// don't run the next test on gfx11 devices
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to bf8 and back, check if holds
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
#endif
#endif
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_fnuz_t
>
(
max_bf8_t_float
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
// convert inf float to bf8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to bf8 and back, check if holds
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP32Stochastic
)
TEST
(
BF8
FNUZ
,
ConvertFP32Stochastic
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// convert minimal float to bf8 and back, check if holds
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_fnuz_t
>
(
max_bf8_t_float
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
// convert inf float to bf8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to bf8 and back, check if holds
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP16Nearest
)
TEST
(
BF8
FNUZ
,
ConvertFP16Nearest
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_fnuz_t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const
auto
max_bf8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
ASSERT_NEAR
(
half_t
{
57344.0
}
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
t
>
(
half_t
{
57344.0
}
)),
abs_tol
);
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_t
>
(
max_bf8_t_half
)),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
}
,
ASSERT_NEAR
(
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
// convert QuietNaN fp16 to bf8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP16Stochastic
)
TEST
(
BF8
FNUZ
,
ConvertFP16Stochastic
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const
auto
max_bf8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
ASSERT_NEAR
(
half_t
{
57344.0
}
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
t
>
(
half_t
{
57344.0
}
)),
abs_tol
);
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_t
>
(
max_bf8_t_half
)),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
}
,
ASSERT_NEAR
(
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
// convert QuietNaN fp16 to bf8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
}
test/data_type/test_bf8_ocp.cpp
0 → 100644
View file @
1504c3e8
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_ocp_t
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
TEST
(
BF8OCP
,
NumericLimits
)
{
// constants given for OCP FP8
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Min
(),
type_convert
<
bf8_ocp_t
>
(
0x04
));
// 0b00000100 = 2^-14
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
type_convert
<
bf8_ocp_t
>
(
0x7B
));
// 0b01111011 = 57344
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Lowest
(),
type_convert
<
bf8_ocp_t
>
(
0xFB
));
// 0b11111011 = -57344
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
().
data
,
type_convert
<
bf8_ocp_t
>
(
0x7D
).
data
);
// 0b01111101
EXPECT_FALSE
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
()
==
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
());
EXPECT_TRUE
(
ck
::
fp8_is_inf
(
type_convert
<
bf8_ocp_t
>
(
0xFC
))
&&
ck
::
fp8_is_inf
(
type_convert
<
bf8_ocp_t
>
(
0x7C
)));
}
TEST
(
BF8OCP
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
max_bf8_t_float
)),
0.0
f
);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive normal float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
// 10*2^-17
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_float
)),
abs_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
auto
neg_min_bf8
=
-
0.00006103515625
f
;
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
neg_min_bf8
)),
0.0
f
);
// positive subnorm float value to bf8 and back, check if holds
constexpr
auto
pos_subnorm_bf8
=
0.000030517578125
f
;
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
0.0
f
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
auto
min_subnorm_bf8
=
-
0.0000152587890625
f
;
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
0.0
f
);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr
auto
less_than_min_subnorm
=
0.00000762939453125
f
;
// 2^-17
ASSERT_EQ
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
max_bf8_t_float
)),
0.0
f
);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive normal float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
// 10*2^-17
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_float
)),
abs_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
auto
neg_min_bf8
=
-
0.00006103515625
f
;
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
neg_min_bf8
)),
0.0
f
);
// positive subnorm float value to bf8 and back, check if holds
constexpr
auto
pos_subnorm_bf8
=
0.000030517578125
f
;
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
0.0
f
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
auto
min_subnorm_bf8
=
-
0.0000152587890625
f
;
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
0.0
f
);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr
auto
less_than_min_subnorm
=
0.00000762939453125
f
;
// 2^-17
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
less_than_min_subnorm
)),
0.0000152587890625
f
);
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP16Nearest
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_tol
);
const
auto
max_bf8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
max_bf8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_rne
<
bf8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_norm_bf8
{
0.0000762939
f
};
// 10*2^-17
ASSERT_NEAR
(
pos_norm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_norm_bf8
)),
half_t_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
half_t
neg_min_bf8
{
-
0.00006103515625
f
};
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
neg_min_bf8
)),
half_t_zero
);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_subnorm_bf8
{
0.000030517578125
f
};
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
half_t_zero
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
min_subnorm_bf8
{
-
0.0000152587890625
f
};
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
half_t_zero
);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr
half_t
less_than_min_subnorm
{
0.00000762939453125
f
};
// 2^-17
ASSERT_EQ
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
constexpr
auto
min_subnorm_bf8
=
0.0000152587890625
f
;
// 2^-16
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t (6.103515625e-05) to fp8 and back
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_zero
);
const
auto
max_bf8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
max_bf8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_sr
<
bf8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_norm_bf8
{
0.0000762939
f
};
// 10*2^-17
ASSERT_NEAR
(
pos_norm_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_norm_bf8
)),
half_t_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
half_t
neg_min_bf8
{
-
0.00006103515625
f
};
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
neg_min_bf8
)),
half_t_zero
);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_subnorm_bf8
{
0.000030517578125
f
};
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
half_t_zero
);
// min subnorm bf8 value to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
-
min_subnorm_bf8
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
half_t
{
-
min_subnorm_bf8
})),
half_t_zero
);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr
half_t
less_than_min_subnorm
{
0.00000762939453125
f
};
// 2^-17
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
less_than_min_subnorm
)),
half_t
{
min_subnorm_bf8
});
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
test/data_type/test_custom_type.cpp
View file @
1504c3e8
...
@@ -872,3 +872,153 @@ TEST(Complex_half, TestAsTypeReshape)
...
@@ -872,3 +872,153 @@ TEST(Complex_half, TestAsTypeReshape)
test_vec
.
at
(
num_elem
*
i
+
1
));
test_vec
.
at
(
num_elem
*
i
+
1
));
});
});
}
}
#if CK_USE_OCP_FP8
TEST
(
FP8OCP
,
TestSize
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
ASSERT_EQ
(
sizeof
(
f8_t
),
sizeof
(
ck
::
fp8_storage_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
2
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
4
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
8
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
16
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
32
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
64
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
64
>
));
}
TEST
(
FP8OCP
,
TestAsType
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
4
,
-
2
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
f8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
));
});
// copy the vector
vector_type
<
f8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
)));
});
}
TEST
(
FP8OCP
,
TestAsTypeReshape
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
8
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
/
256
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
f8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
));
});
// copy the first half of a vector
vector_type
<
f8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
f8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
)));
});
}
TEST
(
BF8OCP
,
TestSize
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
ASSERT_EQ
(
sizeof
(
bf8_t
),
sizeof
(
ck
::
fp8_storage_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
2
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
4
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
8
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
16
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
32
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
64
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
64
>
));
}
TEST
(
BF8OCP
,
TestAsType
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
4
,
-
2
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
bf8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
));
});
// copy the vector
vector_type
<
bf8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
)));
});
}
TEST
(
BF8OCP
,
TestAsTypeReshape
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
8
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
/
256
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
bf8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
));
});
// copy the first half of a vector
vector_type
<
bf8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
bf8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
)));
});
}
#endif
test/data_type/test_fp8.cpp
→
test/data_type/test_fp8
_fnuz
.cpp
View file @
1504c3e8
...
@@ -7,154 +7,171 @@
...
@@ -7,154 +7,171 @@
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_t
;
using
ck
::
f8_
fnuz_
t
;
using
ck
::
half_t
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
using
ck
::
type_convert
;
TEST
(
FP8
,
NumericLimits
)
TEST
(
FP8
FNUZ
,
NumericLimits
)
{
{
// constants given for negative zero nan mode
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Min
(),
type_convert
<
f8_t
>
(
0x08
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Min
(),
type_convert
<
f8_
fnuz_
t
>
(
0x08
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Max
(),
type_convert
<
f8_t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Max
(),
type_convert
<
f8_
fnuz_
t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Lowest
(),
type_convert
<
f8_t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Lowest
(),
type_convert
<
f8_
fnuz_
t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
QuietNaN
(),
type_convert
<
f8_t
>
(
0x80
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
QuietNaN
(),
type_convert
<
f8_
fnuz_
t
>
(
0x80
));
}
}
TEST
(
FP8
,
ConvertFP32Nearest
)
TEST
(
FP8
FNUZ
,
ConvertFP32Nearest
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// don't run the next test on gfx11 devices
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to fp8 and back, check if holds
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
#endif
#endif
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
240.0
f
)),
abs_tol
);
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal float to fp8 and back, check if clipped to 240.0
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR
(
240.0
f
,
ASSERT_NEAR
(
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
max_f8_t_float
)),
abs_tol
);
// XXX: FNUZ f8_convert_rne behavior is inconsistent.
// Clipping large values to fp8 max (saturation to finite) contradicts converting inf float to
// fp8 qNAN (no saturation).
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to f8_t and check if it is qNan
// convert inf float to f8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8_convert_rne
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to fp8 and back, check if holds
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp8 and back, check if holds
// negative subnorm float value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
neg_float
=
-
0.001953125
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
FP8
,
ConvertFP32Stochastic
)
TEST
(
FP8
FNUZ
,
ConvertFP32Stochastic
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// convert minimal float to fp8 and back, check if holds
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
240.0
f
)),
abs_tol
);
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal float to fp8 and back, check if clipped to 240.0
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR
(
240.0
f
,
ASSERT_NEAR
(
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
max_f8_t_float
)),
abs_tol
);
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to f8_t and check if it is qNan
// convert inf float to f8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8_convert_sr
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to fp8 and back, check if holds
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp8 and back, check if holds
// negative subnorm float value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
neg_float
=
-
0.001953125
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
FP8
,
ConvertFP16Nearest
)
TEST
(
FP8
FNUZ
,
ConvertFP16Nearest
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to fp8 and back, check if holds
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to fp8 and back, check if holds
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
half_t
{
240.0
})),
abs_tol
);
const
auto
max_f8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR
(
half_t
{
240.0
},
ASSERT_NEAR
(
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
max_f8_t_half
)),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
// convert QuietNaN fp16 to f8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8_convert_rne
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to fp8 and back, check if holds
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.001953125
};
neg_half
=
half_t
{
-
0.001953125
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
}
TEST
(
FP8
,
ConvertFP16Stochastic
)
TEST
(
FP8
FNUZ
,
ConvertFP16Stochastic
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to fp8 and back, check if holds
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to fp8 and back, check if holds
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
half_t
{
240.0
})),
abs_tol
);
const
auto
max_f8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR
(
half_t
{
240.0
},
ASSERT_NEAR
(
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
max_f8_t_half
)),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
// convert QuietNaN fp16 to f8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8_convert_sr
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to fp8 and back, check if holds
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.001953125
};
neg_half
=
half_t
{
-
0.001953125
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
}
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment