Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
965b7ba4
Unverified
Commit
965b7ba4
authored
Dec 02, 2024
by
Illia Silin
Committed by
GitHub
Dec 02, 2024
Browse files
Merge pull request #229 from ROCm/promote_ocp_fp8
Promote ocp fp8
parents
5dff1b14
62e3c582
Changes
63
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1174 additions
and
314 deletions
+1174
-314
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+360
-83
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+2
-2
include/ck/utility/random_gen.hpp
include/ck/utility/random_gen.hpp
+8
-5
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+143
-61
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+2
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+5
-5
library/include/ck/library/utility/host_tensor.hpp
library/include/ck/library/utility/host_tensor.hpp
+1
-1
library/include/ck/library/utility/host_tensor_generator.hpp
library/include/ck/library/utility/host_tensor_generator.hpp
+25
-6
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+2
-2
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f8_instance.cpp
...pu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f8_instance.cpp
+2
-2
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
...r/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
+2
-2
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
+2
-2
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
...clude/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
+2
-2
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
...ofiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
+2
-2
profiler/include/profiler/profile_gemm_impl.hpp
profiler/include/profiler/profile_gemm_impl.hpp
+3
-3
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+31
-6
test/data_type/test_bf8_fnuz.cpp
test/data_type/test_bf8_fnuz.cpp
+73
-62
test/data_type/test_bf8_ocp.cpp
test/data_type/test_bf8_ocp.cpp
+268
-0
test/data_type/test_custom_type.cpp
test/data_type/test_custom_type.cpp
+158
-0
test/data_type/test_fp8_fnuz.cpp
test/data_type/test_fp8_fnuz.cpp
+83
-66
No files found.
include/ck/utility/data_type.hpp
View file @
965b7ba4
This diff is collapsed.
Click to expand it.
include/ck/utility/math_v2.hpp
View file @
965b7ba4
...
@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
...
@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
return
(
xx
&
0x7FFF
)
>
0x7C00
;
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
};
static
inline
__host__
bool
isnan
(
f8_t
x
)
{
return
(
x
&
0x80
);
};
static
inline
__host__
bool
isnan
(
f8_t
x
)
{
return
ck
::
fp8_is_nan
(
x
);
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__host__
bool
isnan
(
int4_t
x
)
static
inline
__host__
bool
isnan
(
int4_t
x
)
...
@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
...
@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
return
(
xx
&
0x7FFF
)
>
0x7C00
;
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
};
static
inline
__device__
bool
isnan
(
f8_t
x
)
{
return
(
x
&
0x80
);
};
static
inline
__device__
bool
isnan
(
f8_t
x
)
{
return
ck
::
fp8_is_nan
(
x
);
};
static
inline
__device__
half_t
sqrt
(
half_t
x
)
static
inline
__device__
half_t
sqrt
(
half_t
x
)
{
{
...
...
include/ck/utility/random_gen.hpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/ck.hpp"
namespace
ck
{
namespace
ck
{
// Pseudo random number generator
// Pseudo random number generator
...
@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
...
@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
}
// version for fp16
// version for fp16
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<
std
::
is_same
<
half_t
,
T
>{},
bool
>
=
false
>
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<
std
::
is_same
<
_Float16
,
T
>{},
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
{
uint16_t
x
=
*
(
reinterpret_cast
<
uint16_t
*>
(
&
val
));
uint16_t
x
=
*
(
reinterpret_cast
<
uint16_t
*>
(
&
val
));
...
@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
...
@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
}
// return 0 if data is not fp16 or fp32
// return 0 if data is not fp16 or fp32
template
<
typename
T
,
template
<
uint32_t
seed_t
,
typename
T
,
std
::
enable_if_t
<!
(
std
::
is_same
<
float
,
T
>{}
||
std
::
is_same
<
half_t
,
T
>
{}),
bool
>
=
false
>
uint32_t
seed_t
,
std
::
enable_if_t
<!
(
std
::
is_same
<
float
,
T
>{}
||
std
::
is_same
<
_Float16
,
T
>
{}),
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
int
id
,
T
val
,
uint32_t
seed
=
seed_t
)
__host__
__device__
uint32_t
prand_generator
(
int
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
{
std
::
ignore
=
id
;
std
::
ignore
=
id
;
...
...
include/ck/utility/type_convert.hpp
View file @
965b7ba4
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
ck
{
namespace
ck
{
// Define the common macro for MI300 models
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|| defined(__gfx950__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#define __gfx94__
#endif
#endif
...
@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
...
@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
}
template
<
>
inline
__host__
__device__
constexpr
f8_ocp_t
type_convert
<
f8_ocp_t
,
int
>
(
int
x
)
{
return
f8_ocp_t
{
type_convert
<
f8_ocp_t
::
data_type
>
(
x
)};
}
template
<
>
inline
__host__
__device__
constexpr
bf8_ocp_t
type_convert
<
bf8_ocp_t
,
int
>
(
int
x
)
{
return
bf8_ocp_t
{
type_convert
<
bf8_ocp_t
::
data_type
>
(
x
)};
}
// Convert X to Y
// Convert X to Y
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert_sp
(
X
x
)
__host__
__device__
constexpr
Y
type_convert_sp
(
X
x
)
...
@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
...
@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
// convert fp32 to fp8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_sr
<
f8_
fnuz_
t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
...
@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
...
@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
return
utils
::
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
cast_to_f8
<
float
,
f8_
fnuz_
t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
rng
);
x
,
rng
);
#endif
#endif
}
}
// convert fp16 to fp8 with stochastic rounding
// convert fp16 to fp8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_sr
<
f8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_sr
<
f8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
f8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
#endif
}
}
// convert fp32 to bf8 with stochastic rounding
// convert fp32 to bf8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_sr
<
bf8_
fnuz_
t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
...
@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
...
@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
return
utils
::
return
utils
::
cast_to_f8
<
float
,
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
bf8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
#endif
}
}
// convert fp16 to bf8 with stochastic rounding
// convert fp16 to bf8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_sr
<
bf8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_sr
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_sr
<
bf8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
bf8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
#endif
}
}
...
@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
...
@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
// convert fp32 to fp8 with rounding to nearest even
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_rne
<
f8_
fnuz_
t
,
float
>
(
float
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
union
union
...
@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
...
@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
cast_to_f8
<
float
,
f8_
fnuz_
t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
rng
);
x
,
rng
);
#endif
#endif
}
}
// convert fp16 to fp8 with rounding to nearest even
// convert fp16 to fp8 with rounding to nearest even
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_rne
<
f8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_rne
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_rne
<
f8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
f8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
#endif
}
}
// convert fp32 to bf8 with rounding to nearest even
// convert fp32 to bf8 with rounding to nearest even
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_rne
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_rne
<
bf8_
fnuz_
t
,
float
>
(
float
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
union
union
...
@@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
...
@@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
return
utils
::
cast_to_f8
<
float
,
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
bf8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
#endif
}
}
// convert fp16 to bf8 with rounding to nearest even
// convert fp16 to bf8 with rounding to nearest even
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_rne
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_rne
<
bf8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// convert to float and use native converion
// convert to float and use native converion
return
f8_convert_rne
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_rne
<
bf8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
bf8_fnuz_t
,
x
,
rng
);
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp32 to fp8
template
<
>
inline
__host__
__device__
f8_fnuz_t
type_convert
<
f8_fnuz_t
,
float
>
(
float
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_fnuz_t
>
(
x
);
#else
return
f8_convert_rne
<
f8_fnuz_t
>
(
x
);
#endif
#endif
}
}
// convert fp32 to fp8
// convert fp32 to fp8
template
<
>
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_
ocp_
t
type_convert
<
f8_
ocp_
t
,
float
>
(
float
x
)
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_t
>
(
x
);
return
f8_convert_sr
<
f8_
ocp_
t
>
(
x
);
#else
#else
return
f8_convert_rne
<
f8_t
>
(
x
);
return
f8_convert_rne
<
f8_
ocp_
t
>
(
x
);
#endif
#endif
}
}
// convert fp8 to fp32
// convert fp8 to fp32
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
f8_
t
>
(
f8
_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
f8_
fnuz_t
>
(
f8_fnuz
_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
float
fval
;
float
fval
;
...
@@ -392,30 +427,44 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
...
@@ -392,30 +427,44 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return
fval
;
return
fval
;
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
x
);
#endif
#endif
}
}
template
<
>
template
<
>
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f8x2_t
>
(
f8x2_t
x
)
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f8x2_
fnuz_
t
>
(
f8x2_
fnuz_
t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
const
auto
i16val
=
bit_cast
<
uint16_t
>
(
x
);
const
auto
i16val
=
bit_cast
<
uint16_t
>
(
x
);
return
__builtin_amdgcn_cvt_pk_f32_fp8
(
i16val
,
0
);
return
__builtin_amdgcn_cvt_pk_f32_fp8
(
i16val
,
0
);
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
const
auto
f8x2_v
=
vector_type
<
f8_t
,
2
>
(
x
);
const
auto
f8x2_v
=
vector_type
<
f8_
fnuz_
t
,
2
>
(
x
);
vector_type
<
float
,
2
>
f32x2_v
;
vector_type
<
float
,
2
>
f32x2_v
;
f32x2_v
.
template
AsType
<
float
>()(
Number
<
0
>
{})
=
f32x2_v
.
template
AsType
<
float
>()(
Number
<
0
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_t
>()[
Number
<
0
>
{}]);
f8x2_v
.
template
AsType
<
f8_
fnuz_
t
>()[
Number
<
0
>
{}]);
f32x2_v
.
template
AsType
<
float
>()(
Number
<
1
>
{})
=
f32x2_v
.
template
AsType
<
float
>()(
Number
<
1
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_t
>()[
Number
<
1
>
{}]);
f8x2_v
.
template
AsType
<
f8_
fnuz_
t
>()[
Number
<
1
>
{}]);
return
f32x2_v
.
template
AsType
<
float2_t
>()[
Number
<
0
>
{}];
return
f32x2_v
.
template
AsType
<
float2_t
>()[
Number
<
0
>
{}];
#endif
#endif
}
}
template
<
>
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f8x2_ocp_t
>
(
f8x2_ocp_t
x
)
{
#if CK_OCP_FP8_CVT_FAST_PATH
return
fp8_impl
::
cast_to_f32x2_from_f8x2
<
f8_ocp_t
::
default_interpret
>
(
x
.
AsType
<
fp8_impl
::
fp8x2_storage_t
>
()[
Number
<
0
>
{}]);
#else
return
float2_t
{
fp8_impl
::
cast_from_f8
<
float
,
f8_ocp_t
::
wm
,
f8_ocp_t
::
we
,
false
>
(
x
.
AsType
<
fp8_storage_t
>
()[
Number
<
0
>
{}]),
fp8_impl
::
cast_from_f8
<
float
,
f8_ocp_t
::
wm
,
f8_ocp_t
::
we
,
false
>
(
x
.
AsType
<
fp8_storage_t
>
()[
Number
<
1
>
{}])};
#endif
}
template
<
>
template
<
>
inline
__host__
__device__
half2_t
type_convert
<
half2_t
,
float2_t
>
(
float2_t
x
)
inline
__host__
__device__
half2_t
type_convert
<
half2_t
,
float2_t
>
(
float2_t
x
)
{
{
...
@@ -428,42 +477,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
...
@@ -428,42 +477,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// convert fp16 to fp8
// convert fp16 to fp8
template
<
>
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_
fnuz_
t
type_convert
<
f8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_t
>
(
x
);
return
f8_convert_sr
<
f8_
fnuz_
t
>
(
x
);
#else
#else
return
f8_convert_rne
<
f8_t
>
(
x
);
return
f8_convert_rne
<
f8_fnuz_t
>
(
x
);
#endif
}
// convert fp16 to fp8
template
<
>
inline
__host__
__device__
f8_ocp_t
type_convert
<
f8_ocp_t
,
half_t
>
(
half_t
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_ocp_t
>
(
x
);
#else
return
f8_convert_rne
<
f8_ocp_t
>
(
x
);
#endif
#endif
}
}
// convert fp8 to fp16
// convert fp8 to fp16
template
<
>
template
<
>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_
t
>
(
f8
_t
x
)
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_
fnuz_t
>
(
f8_fnuz
_t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_fnuz_t
,
half_t
,
negative_zero_nan
>
(
x
);
#endif
}
// convert fp32 to bf8
template
<
>
inline
__host__
__device__
bf8_fnuz_t
type_convert
<
bf8_fnuz_t
,
float
>
(
float
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_fnuz_t
>
(
x
);
#else
return
f8_convert_rne
<
bf8_fnuz_t
>
(
x
);
#endif
#endif
}
}
// convert fp32 to bf8
// convert fp32 to bf8
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_
ocp_
t
type_convert
<
bf8_
ocp_
t
,
float
>
(
float
x
)
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_t
>
(
x
);
return
f8_convert_sr
<
bf8_
ocp_
t
>
(
x
);
#else
#else
return
f8_convert_rne
<
bf8_t
>
(
x
);
return
f8_convert_rne
<
bf8_
ocp_
t
>
(
x
);
#endif
#endif
}
}
// convert bf8 to fp32
// convert bf8 to fp32
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
bf8_t
>
(
bf8_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
bf8_
fnuz_
t
>
(
bf8_
fnuz_
t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
float
fval
;
float
fval
;
...
@@ -473,31 +544,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
...
@@ -473,31 +544,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
return
fval
;
return
fval
;
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_fnuz_t
,
float
,
negative_zero_nan
>
(
x
);
#endif
}
// convert fp16 to bf8
template
<
>
inline
__host__
__device__
bf8_fnuz_t
type_convert
<
bf8_fnuz_t
,
half_t
>
(
half_t
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_fnuz_t
>
(
x
);
#else
return
f8_convert_rne
<
bf8_fnuz_t
>
(
x
);
#endif
#endif
}
}
// convert fp16 to bf8
// convert fp16 to bf8
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_
ocp_
t
type_convert
<
bf8_
ocp_
t
,
half_t
>
(
half_t
x
)
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_t
>
(
x
);
return
f8_convert_sr
<
bf8_
ocp_
t
>
(
x
);
#else
#else
return
f8_convert_rne
<
bf8_t
>
(
x
);
return
f8_convert_rne
<
bf8_
ocp_
t
>
(
x
);
#endif
#endif
}
}
// convert bf8 to fp16
// convert bf8 to fp16
template
<
>
template
<
>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
bf8_t
>
(
bf8_t
x
)
inline
__host__
__device__
half_t
type_convert
<
half_t
,
bf8_
fnuz_
t
>
(
bf8_
fnuz_
t
x
)
{
{
#if defined(__gfx94__)
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_
fnuz_
t
,
half_t
,
negative_zero_nan
>
(
x
);
#endif
#endif
}
}
...
...
include/ck_tile/core/config.hpp
View file @
965b7ba4
...
@@ -4,10 +4,10 @@
...
@@ -4,10 +4,10 @@
#pragma once
#pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
|| defined(__gfx950__)
defined(__gfx942__)
#define __gfx9__
#define __gfx9__
#endif
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|| defined(__gfx950__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#define __gfx94__
#endif
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -62,9 +62,9 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -62,9 +62,9 @@ struct ReferenceGemm : public device::BaseOperator
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
AccDataType
v_acc
=
0
;
AccDataType
v_acc
{
0
}
;
ComputeTypeA
v_a
=
0
;
ComputeTypeA
v_a
{
0
}
;
ComputeTypeB
v_b
=
0
;
ComputeTypeB
v_b
{
0
}
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
...
@@ -93,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -93,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
CDataType
v_c
=
0
;
CDataType
v_c
{
0
}
;
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_element_op_
(
v_c
,
v_acc
);
...
...
library/include/ck/library/utility/host_tensor.hpp
View file @
965b7ba4
...
@@ -326,7 +326,7 @@ struct Tensor
...
@@ -326,7 +326,7 @@ struct Tensor
std
::
size_t
GetElementSpaceSizeInBytes
()
const
{
return
sizeof
(
T
)
*
GetElementSpaceSize
();
}
std
::
size_t
GetElementSpaceSizeInBytes
()
const
{
return
sizeof
(
T
)
*
GetElementSpaceSize
();
}
void
SetZero
()
{
ck
::
ranges
::
fill
<
T
>
(
mData
,
0
);
}
void
SetZero
()
{
ck
::
ranges
::
fill
<
T
>
(
mData
,
T
{
0
}
);
}
template
<
typename
F
>
template
<
typename
F
>
void
ForEach_impl
(
F
&&
f
,
std
::
vector
<
size_t
>&
idx
,
size_t
rank
)
void
ForEach_impl
(
F
&&
f
,
std
::
vector
<
size_t
>&
idx
,
size_t
rank
)
...
...
library/include/ck/library/utility/host_tensor_generator.hpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
...
@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
float
value
=
1.0
;
float
value
=
1.0
;
template
<
typename
...
Is
>
template
<
typename
...
Is
>
ck
::
b
half_t
operator
()(
Is
...)
ck
::
half_t
operator
()(
Is
...)
{
{
return
ck
::
type_convert
<
ck
::
half_t
>
(
value
);
return
ck
::
type_convert
<
ck
::
half_t
>
(
value
);
}
}
...
@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
...
@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
float
value
=
1.0
;
float
value
=
1.0
;
template
<
typename
...
Is
>
template
<
typename
...
Is
>
ck
::
bhal
f_t
operator
()(
Is
...)
ck
::
f
8
_t
operator
()(
Is
...)
{
{
return
ck
::
type_convert
<
ck
::
f8_t
>
(
value
);
return
ck
::
type_convert
<
ck
::
f8_t
>
(
value
);
}
}
...
@@ -256,14 +256,33 @@ struct GeneratorTensor_Checkboard
...
@@ -256,14 +256,33 @@ struct GeneratorTensor_Checkboard
}
}
};
};
template
<
ck
::
index_t
Dim
>
/**
* @brief Is used to generate sequential values based on the specified dimension.
*
* @tparam T The type of the tensor values.
* @tparam Dim The specific dimension used for generation.
*
* GeneratorTensor_Sequential<1>{} will generate the following values for a 3x3 tensor:
*
* 0 1 2
* 0 1 2
* 0 1 2
*
* Essentially, the values generated are logical coordinates of the generated element that
* correspond to dimension Dim. E.g. for 2-dimensional tensor and Dim=1, the values are the column
* indices.
*
*/
template
<
typename
T
,
ck
::
index_t
Dim
>
struct
GeneratorTensor_Sequential
struct
GeneratorTensor_Sequential
{
{
template
<
typename
...
Ts
>
template
<
typename
...
Ts
>
float
operator
()(
Ts
...
Xs
)
const
T
operator
()(
Ts
...
Xs
)
const
{
{
std
::
array
<
ck
::
index_t
,
sizeof
...(
Ts
)
>
dims
=
{{
static_cast
<
ck
::
index_t
>
(
Xs
)...}};
std
::
array
<
ck
::
index_t
,
sizeof
...(
Ts
)
>
dims
=
{{
static_cast
<
ck
::
index_t
>
(
Xs
)...}};
return
dims
[
Dim
];
float
tmp
=
dims
[
Dim
];
return
ck
::
type_convert
<
T
>
(
tmp
);
}
}
};
};
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
965b7ba4
...
@@ -62,7 +62,7 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -62,7 +62,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach
()
endforeach
()
# Do not build mha instances if gfx94 or gfx90a targets are not on the target list
# Do not build mha instances if gfx94 or gfx90a targets are not on the target list
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND NOT INST_TARGETS MATCHES
"gfx90a"
AND source MATCHES
"mha"
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND NOT INST_TARGETS MATCHES
"gfx90a"
AND source MATCHES
"mha"
)
message
(
"removing mha instance
${
source
}
"
)
message
(
"removing mha instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
...
@@ -331,7 +331,7 @@ if(CK_DEVICE_CONV_INSTANCES)
...
@@ -331,7 +331,7 @@ if(CK_DEVICE_CONV_INSTANCES)
endif
()
endif
()
if
(
CK_DEVICE_MHA_INSTANCES
)
if
(
CK_DEVICE_MHA_INSTANCES
)
set
(
gpu_list
${
INST_TARGETS
}
)
set
(
gpu_list
${
INST_TARGETS
}
)
if
(
gpu_list MATCHES
"gfx94"
OR gpu_list MATCHES
"gfx90a"
)
if
(
gpu_list MATCHES
"gfx94"
OR gpu_list MATCHES
"gfx90a"
)
add_library
(
device_mha_operations STATIC
${
CK_DEVICE_MHA_INSTANCES
}
)
add_library
(
device_mha_operations STATIC
${
CK_DEVICE_MHA_INSTANCES
}
)
add_library
(
composablekernels::device_mha_operations ALIAS device_mha_operations
)
add_library
(
composablekernels::device_mha_operations ALIAS device_mha_operations
)
target_compile_features
(
device_mha_operations PUBLIC
)
target_compile_features
(
device_mha_operations PUBLIC
)
...
...
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f8_instance.cpp
View file @
965b7ba4
...
@@ -15,7 +15,7 @@ void add_device_pool3d_fwd_ndhwc_f8_instances(
...
@@ -15,7 +15,7 @@ void add_device_pool3d_fwd_ndhwc_f8_instances(
instances
)
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
F8
,
F8
,
I32
,
F
8
,
ReduceOpId
,
false
>
{});
instances
,
device_pool3d_fwd_ndhwc_instances
<
F8
,
F8
,
I32
,
F
32
,
ReduceOpId
,
false
>
{});
}
}
void
add_device_pool3d_fwd_ndhwc_index_f8_instances
(
void
add_device_pool3d_fwd_ndhwc_index_f8_instances
(
...
@@ -23,7 +23,7 @@ void add_device_pool3d_fwd_ndhwc_index_f8_instances(
...
@@ -23,7 +23,7 @@ void add_device_pool3d_fwd_ndhwc_index_f8_instances(
instances
)
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
F8
,
F8
,
I32
,
F
8
,
ReduceOpId
,
true
>
{});
instances
,
device_pool3d_fwd_ndhwc_instances
<
F8
,
F8
,
I32
,
F
32
,
ReduceOpId
,
true
>
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -150,7 +150,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
...
@@ -150,7 +150,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
break
;
break
;
default:
default:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
D0DataType
>
{
1
});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
D0DataType
>
{
1
});
}
}
...
...
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
...
@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
break
;
break
;
default:
default:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
}
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
...
@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
break
;
break
;
default:
default:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
}
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
...
@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
break
;
break
;
default:
default:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
}
...
...
profiler/include/profiler/profile_gemm_impl.hpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -74,8 +74,8 @@ int profile_gemm_impl(int do_verification,
...
@@ -74,8 +74,8 @@ int profile_gemm_impl(int do_verification,
switch
(
init_method
)
switch
(
init_method
)
{
{
case
0
:
case
0
:
ck
::
utils
::
FillConstant
<
ADataType
>
{
static_cas
t
<
ADataType
>
(
1.
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
ADataType
>
{
type_conver
t
<
ADataType
>
(
1.
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
static_cas
t
<
BDataType
>
(
1.
f
)}(
b_k_n
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
type_conver
t
<
BDataType
>
(
1.
f
)}(
b_k_n
);
break
;
break
;
case
1
:
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
...
...
test/data_type/CMakeLists.txt
View file @
965b7ba4
...
@@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4)
...
@@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4)
endif
()
endif
()
endif
()
endif
()
add_gtest_executable
(
test_fp8 test_fp8.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8 PRIVATE utility
)
add_custom_target
(
test_fp8
)
if
(
CK_USE_OCP_FP8
)
add_gtest_executable
(
test_fp8_ocp test_fp8_ocp.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8_ocp PRIVATE utility
)
endif
()
add_gtest_executable
(
test_bf8_ocp test_bf8_ocp.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf8_ocp PRIVATE utility
)
endif
()
add_dependencies
(
test_fp8 test_fp8_ocp
)
add_dependencies
(
test_fp8 test_bf8_ocp
)
endif
()
endif
()
add_gtest_executable
(
test_bf8 test_bf8.cpp
)
if
(
result EQUAL 0
)
if
(
CK_USE_FNUZ_FP8
)
target_link_libraries
(
test_bf8 PRIVATE utility
)
add_gtest_executable
(
test_fp8_fnuz test_fp8_fnuz.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8_fnuz PRIVATE utility
)
endif
()
add_gtest_executable
(
test_bf8_fnuz test_bf8_fnuz.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf8_fnuz PRIVATE utility
)
endif
()
add_dependencies
(
test_fp8 test_fp8_fnuz
)
add_dependencies
(
test_fp8 test_bf8_fnuz
)
endif
()
endif
()
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
...
...
test/data_type/test_bf8.cpp
→
test/data_type/test_bf8
_fnuz
.cpp
View file @
965b7ba4
...
@@ -5,158 +5,169 @@
...
@@ -5,158 +5,169 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_t
;
using
ck
::
bf8_
fnuz_
t
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_convert_sr
;
using
ck
::
half_t
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
using
ck
::
type_convert
;
TEST
(
BF8
,
NumericLimits
)
TEST
(
BF8
FNUZ
,
NumericLimits
)
{
{
// constants given for negative zero nan mode
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Min
(),
type_convert
<
bf8_t
>
(
0x04
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Min
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x04
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Max
(),
type_convert
<
bf8_t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Max
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Lowest
(),
type_convert
<
bf8_t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Lowest
(),
type_convert
<
bf8_
fnuz_
t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
QuietNaN
(),
type_convert
<
bf8_t
>
(
0x80
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
QuietNaN
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x80
));
}
}
TEST
(
BF8
,
ConvertFP32Nearest
)
TEST
(
BF8
FNUZ
,
ConvertFP32Nearest
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// don't run the next test on gfx11 devices
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to bf8 and back, check if holds
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
#endif
#endif
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_fnuz_t
>
(
max_bf8_t_float
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
// convert inf float to bf8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to bf8 and back, check if holds
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP32Stochastic
)
TEST
(
BF8
FNUZ
,
ConvertFP32Stochastic
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// convert minimal float to bf8 and back, check if holds
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_fnuz_t
>
(
max_bf8_t_float
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
// convert inf float to bf8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to bf8 and back, check if holds
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP16Nearest
)
TEST
(
BF8
FNUZ
,
ConvertFP16Nearest
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_fnuz_t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const
auto
max_bf8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
ASSERT_NEAR
(
half_t
{
57344.0
}
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
t
>
(
half_t
{
57344.0
}
)),
abs_tol
);
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_t
>
(
max_bf8_t_half
)),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
}
,
ASSERT_NEAR
(
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
// convert QuietNaN fp16 to bf8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP16Stochastic
)
TEST
(
BF8
FNUZ
,
ConvertFP16Stochastic
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const
auto
max_bf8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
ASSERT_NEAR
(
half_t
{
57344.0
}
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
t
>
(
half_t
{
57344.0
}
)),
abs_tol
);
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_t
>
(
max_bf8_t_half
)),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
}
,
ASSERT_NEAR
(
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
// convert QuietNaN fp16 to bf8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
}
test/data_type/test_bf8_ocp.cpp
0 → 100644
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_ocp_t
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
TEST
(
BF8OCP
,
NumericLimits
)
{
// constants given for OCP FP8
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Min
(),
type_convert
<
bf8_ocp_t
>
(
0x04
));
// 0b00000100 = 2^-14
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
type_convert
<
bf8_ocp_t
>
(
0x7B
));
// 0b01111011 = 57344
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Lowest
(),
type_convert
<
bf8_ocp_t
>
(
0xFB
));
// 0b11111011 = -57344
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
().
data
,
type_convert
<
bf8_ocp_t
>
(
0x7D
).
data
);
// 0b01111101
EXPECT_FALSE
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
()
==
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
());
EXPECT_TRUE
(
ck
::
fp8_is_inf
(
type_convert
<
bf8_ocp_t
>
(
0xFC
))
&&
ck
::
fp8_is_inf
(
type_convert
<
bf8_ocp_t
>
(
0x7C
)));
}
TEST
(
BF8OCP
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
max_bf8_t_float
)),
0.0
f
);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive normal float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
// 10*2^-17
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_float
)),
abs_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
auto
neg_min_bf8
=
-
0.00006103515625
f
;
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
neg_min_bf8
)),
0.0
f
);
// positive subnorm float value to bf8 and back, check if holds
constexpr
auto
pos_subnorm_bf8
=
0.000030517578125
f
;
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
0.0
f
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
auto
min_subnorm_bf8
=
-
0.0000152587890625
f
;
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
0.0
f
);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr
auto
less_than_min_subnorm
=
0.00000762939453125
f
;
// 2^-17
ASSERT_EQ
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
max_bf8_t_float
)),
0.0
f
);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive normal float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
// 10*2^-17
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_float
)),
abs_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
auto
neg_min_bf8
=
-
0.00006103515625
f
;
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
neg_min_bf8
)),
0.0
f
);
// positive subnorm float value to bf8 and back, check if holds
constexpr
auto
pos_subnorm_bf8
=
0.000030517578125
f
;
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
0.0
f
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
auto
min_subnorm_bf8
=
-
0.0000152587890625
f
;
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
0.0
f
);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr
auto
less_than_min_subnorm
=
0.00000762939453125
f
;
// 2^-17
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
less_than_min_subnorm
)),
0.0000152587890625
f
);
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP16Nearest
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_tol
);
const
auto
max_bf8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
max_bf8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_rne
<
bf8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_norm_bf8
{
0.0000762939
f
};
// 10*2^-17
ASSERT_NEAR
(
pos_norm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_norm_bf8
)),
half_t_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
half_t
neg_min_bf8
{
-
0.00006103515625
f
};
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
neg_min_bf8
)),
half_t_zero
);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_subnorm_bf8
{
0.000030517578125
f
};
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
half_t_zero
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
min_subnorm_bf8
{
-
0.0000152587890625
f
};
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
half_t_zero
);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr
half_t
less_than_min_subnorm
{
0.00000762939453125
f
};
// 2^-17
ASSERT_EQ
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
constexpr
auto
min_subnorm_bf8
=
0.0000152587890625
f
;
// 2^-16
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t (6.103515625e-05) to fp8 and back
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_zero
);
const
auto
max_bf8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
max_bf8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_sr
<
bf8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_norm_bf8
{
0.0000762939
f
};
// 10*2^-17
ASSERT_NEAR
(
pos_norm_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_norm_bf8
)),
half_t_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
half_t
neg_min_bf8
{
-
0.00006103515625
f
};
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
neg_min_bf8
)),
half_t_zero
);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_subnorm_bf8
{
0.000030517578125
f
};
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
half_t_zero
);
// min subnorm bf8 value to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
-
min_subnorm_bf8
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
half_t
{
-
min_subnorm_bf8
})),
half_t_zero
);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr
half_t
less_than_min_subnorm
{
0.00000762939453125
f
};
// 2^-17
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
less_than_min_subnorm
)),
half_t
{
min_subnorm_bf8
});
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
test/data_type/test_custom_type.cpp
View file @
965b7ba4
...
@@ -872,3 +872,161 @@ TEST(Complex_half, TestAsTypeReshape)
...
@@ -872,3 +872,161 @@ TEST(Complex_half, TestAsTypeReshape)
test_vec
.
at
(
num_elem
*
i
+
1
));
test_vec
.
at
(
num_elem
*
i
+
1
));
});
});
}
}
#if CK_USE_OCP_FP8
TEST
(
FP8OCP
,
TestSize
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
ASSERT_EQ
(
sizeof
(
f8_t
),
sizeof
(
ck
::
fp8_storage_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
2
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
4
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
8
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
16
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
32
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
64
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
64
>
));
}
TEST
(
FP8OCP
,
TestAsType
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
4
,
-
2
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
f8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
));
});
// copy the vector
vector_type
<
f8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
)));
});
ck
::
non_native_vector_base
<
ck
::
f8_ocp_t
,
2
>
nnvb_f8x2
(
ck
::
type_convert
<
f8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_f8x2
.
template
AsType
<
f8_t
>()(
Number
<
0
>
{}),
ck
::
type_convert
<
f8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_f8x2
.
template
AsType
<
f8_t
>()(
Number
<
1
>
{}),
ck
::
type_convert
<
f8_t
>
(
-
10.0
f
));
}
TEST
(
FP8OCP
,
TestAsTypeReshape
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
8
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
/
256
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
f8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
));
});
// copy the first half of a vector
vector_type
<
f8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
f8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
)));
});
}
TEST
(
BF8OCP
,
TestSize
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
ASSERT_EQ
(
sizeof
(
bf8_t
),
sizeof
(
ck
::
fp8_storage_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
2
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
4
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
8
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
16
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
32
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
64
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
64
>
));
}
TEST
(
BF8OCP
,
TestAsType
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
4
,
-
2
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
bf8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
));
});
// copy the vector
vector_type
<
bf8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
)));
});
ck
::
non_native_vector_base
<
bf8_t
,
2
>
nnvb_bf8x2
(
ck
::
type_convert
<
bf8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_bf8x2
.
template
AsType
<
bf8_t
>()(
Number
<
0
>
{}),
ck
::
type_convert
<
bf8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_bf8x2
.
template
AsType
<
bf8_t
>()(
Number
<
1
>
{}),
ck
::
type_convert
<
bf8_t
>
(
-
10.0
f
));
}
TEST
(
BF8OCP
,
TestAsTypeReshape
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
8
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
/
256
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
bf8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
));
});
// copy the first half of a vector
vector_type
<
bf8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
bf8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
)));
});
}
#endif
test/data_type/test_fp8.cpp
→
test/data_type/test_fp8
_fnuz
.cpp
View file @
965b7ba4
...
@@ -7,154 +7,171 @@
...
@@ -7,154 +7,171 @@
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_t
;
using
ck
::
f8_
fnuz_
t
;
using
ck
::
half_t
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
using
ck
::
type_convert
;
TEST
(
FP8
,
NumericLimits
)
TEST
(
FP8
FNUZ
,
NumericLimits
)
{
{
// constants given for negative zero nan mode
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Min
(),
type_convert
<
f8_t
>
(
0x08
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Min
(),
type_convert
<
f8_
fnuz_
t
>
(
0x08
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Max
(),
type_convert
<
f8_t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Max
(),
type_convert
<
f8_
fnuz_
t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Lowest
(),
type_convert
<
f8_t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Lowest
(),
type_convert
<
f8_
fnuz_
t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
QuietNaN
(),
type_convert
<
f8_t
>
(
0x80
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
QuietNaN
(),
type_convert
<
f8_
fnuz_
t
>
(
0x80
));
}
}
TEST
(
FP8
,
ConvertFP32Nearest
)
TEST
(
FP8
FNUZ
,
ConvertFP32Nearest
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// don't run the next test on gfx11 devices
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to fp8 and back, check if holds
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
#endif
#endif
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
240.0
f
)),
abs_tol
);
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal float to fp8 and back, check if clipped to 240.0
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR
(
240.0
f
,
ASSERT_NEAR
(
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
max_f8_t_float
)),
abs_tol
);
// XXX: FNUZ f8_convert_rne behavior is inconsistent.
// Clipping large values to fp8 max (saturation to finite) contradicts converting inf float to
// fp8 qNAN (no saturation).
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to f8_t and check if it is qNan
// convert inf float to f8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8_convert_rne
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to fp8 and back, check if holds
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp8 and back, check if holds
// negative subnorm float value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
neg_float
=
-
0.001953125
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
FP8
,
ConvertFP32Stochastic
)
TEST
(
FP8
FNUZ
,
ConvertFP32Stochastic
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// convert minimal float to fp8 and back, check if holds
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
240.0
f
)),
abs_tol
);
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal float to fp8 and back, check if clipped to 240.0
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR
(
240.0
f
,
ASSERT_NEAR
(
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
max_f8_t_float
)),
abs_tol
);
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to f8_t and check if it is qNan
// convert inf float to f8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8_convert_sr
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to fp8 and back, check if holds
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp8 and back, check if holds
// negative subnorm float value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
neg_float
=
-
0.001953125
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
FP8
,
ConvertFP16Nearest
)
TEST
(
FP8
FNUZ
,
ConvertFP16Nearest
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to fp8 and back, check if holds
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to fp8 and back, check if holds
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
half_t
{
240.0
})),
abs_tol
);
const
auto
max_f8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR
(
half_t
{
240.0
},
ASSERT_NEAR
(
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
max_f8_t_half
)),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
// convert QuietNaN fp16 to f8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8_convert_rne
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to fp8 and back, check if holds
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.001953125
};
neg_half
=
half_t
{
-
0.001953125
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
}
TEST
(
FP8
,
ConvertFP16Stochastic
)
TEST
(
FP8
FNUZ
,
ConvertFP16Stochastic
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to fp8 and back, check if holds
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to fp8 and back, check if holds
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
half_t
{
240.0
})),
abs_tol
);
const
auto
max_f8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR
(
half_t
{
240.0
},
ASSERT_NEAR
(
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
max_f8_t_half
)),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
// convert QuietNaN fp16 to f8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8_convert_sr
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to fp8 and back, check if holds
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.001953125
};
neg_half
=
half_t
{
-
0.001953125
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
}
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment