Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
79a4b17f
Commit
79a4b17f
authored
Oct 03, 2024
by
Andriy Roshchenko
Browse files
Initial introduction of OFP8 data types.
parent
171ed358
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
355 additions
and
89 deletions
+355
-89
CMakeLists.txt
CMakeLists.txt
+8
-0
CMakePresets.json
CMakePresets.json
+81
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+195
-30
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+71
-59
No files found.
CMakeLists.txt
View file @
79a4b17f
...
...
@@ -190,6 +190,14 @@ if (GPU_TARGETS)
add_definitions
(
-DCK_USE_WMMA
)
set
(
CK_USE_WMMA
"ON"
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx12"
OR GPU_TARGETS MATCHES
"gfx950"
)
add_definitions
(
-DCK_USE_OCP_FP8
)
set
(
CK_USE_OCP_FP8
"ON"
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx90a"
OR GPU_TARGETS MATCHES
"gfx94"
)
add_definitions
(
-DCK_USE_FNUZ_FP8
)
set
(
CK_USE_FNUZ_FP8
"ON"
)
endif
()
else
()
add_definitions
(
-DCK_USE_WMMA -DCK_USE_XDL
)
set
(
CK_USE_XDL
"ON"
)
...
...
CMakePresets.json
0 → 100644
View file @
79a4b17f
{
"version"
:
3
,
"configurePresets"
:
[
{
"name"
:
"linux-debug"
,
"displayName"
:
"Linux Debug"
,
"hidden"
:
true
,
"generator"
:
"Unix Makefiles"
,
"binaryDir"
:
"${sourceDir}/build/${presetName}"
,
"installDir"
:
"${sourceDir}/build/install/${presetName}"
,
"cacheVariables"
:
{
"CMAKE_BUILD_TYPE"
:
"Debug"
,
"CMAKE_EXPORT_COMPILE_COMMANDS"
:
"ON"
,
"GPU_TARGETS"
:
"gfx90a"
,
"BUILD_DEV"
:
"ON"
,
"CMAKE_CXX_COMPILER"
:
"/opt/rocm/llvm/bin/clang++"
,
"CMAKE_PREFIX_PATH"
:
"/opt/rocm"
},
"condition"
:
{
"type"
:
"equals"
,
"lhs"
:
"${hostSystemName}"
,
"rhs"
:
"Linux"
}
},
{
"name"
:
"MI355-debug"
,
"displayName"
:
"MI355 Debug"
,
"inherits"
:
"linux-debug"
,
"description"
:
"Development Environment for MI355."
,
"environment"
:
{
"NONE"
:
""
},
"cacheVariables"
:
{
"CMAKE_BUILD_TYPE"
:
"Debug"
,
"CMAKE_CXX_FLAGS"
:
"-O0"
}
},
{
"name"
:
"MI355-release"
,
"displayName"
:
"MI355 Release"
,
"inherits"
:
"MI355-debug"
,
"cacheVariables"
:
{
"CMAKE_BUILD_TYPE"
:
"Release"
,
"CMAKE_CXX_FLAGS"
:
"-O3"
}
}
],
"buildPresets"
:
[
{
"name"
:
"Debug"
,
"hidden"
:
true
,
"configuration"
:
"Debug"
},
{
"name"
:
"Release"
,
"hidden"
:
true
,
"configuration"
:
"Release"
},
{
"name"
:
"MI355-debug"
,
"displayName"
:
"MI355"
,
"configurePreset"
:
"MI355-debug"
,
"description"
:
"Build Environment for MI355 Debug."
,
"inherits"
:
[
"Debug"
],
"jobs"
:
128
},
{
"name"
:
"MI355-release"
,
"displayName"
:
"MI355"
,
"configurePreset"
:
"MI355-release"
,
"description"
:
"Build Environment for MI355 Release."
,
"inherits"
:
[
"Release"
],
"jobs"
:
128
}
]
}
include/ck/utility/data_type.hpp
View file @
79a4b17f
...
...
@@ -5,13 +5,51 @@
#include "ck/utility/statically_indexed_array.hpp"
#ifdef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 1
#else
#define CK_USE_FNUZ_FP8 0
#endif
#ifdef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 1
#else
#define CK_USE_OCP_FP8 0
#endif
namespace
ck
{
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
f8_fnuz_t
=
_BitInt
(
8
);
using
bf8_fnuz_t
=
unsigned
_BitInt
(
8
);
typedef
unsigned
char
__hip_fp8_storage_t
;
struct
f8_ocp_t
{
using
type
=
__hip_fp8_storage_t
;
type
data
;
};
struct
bf8_ocp_t
{
using
type
=
__hip_fp8_storage_t
;
type
data
;
};
#if CK_USE_OCP_FP8
using
f8_t
=
f8_ocp_t
;
using
bf8_t
=
bf8_ocp_t
;
#define CK_FP8_TYPE_FNUZ 0
#define CK_FP8_TYPE_OCP 1
#else
using
f8_t
=
f8_fnuz_t
;
using
bf8_t
=
bf8_fnuz_t
;
#define CK_FP8_TYPE_FNUZ 1
#define CK_FP8_TYPE_OCP 0
#endif
// vector_type
template
<
typename
T
,
index_t
N
>
...
...
@@ -150,19 +188,33 @@ struct scalar_type<int4_t>
#endif
template
<
>
struct
scalar_type
<
f8_t
>
struct
scalar_type
<
f8_
fnuz_
t
>
{
using
type
=
f8_t
;
using
type
=
f8_
fnuz_
t
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
struct
scalar_type
<
bf8_t
>
struct
scalar_type
<
bf8_
fnuz_
t
>
{
using
type
=
bf8_t
;
using
type
=
bf8_
fnuz_
t
;
static
constexpr
index_t
vector_size
=
1
;
};
// template <>
// struct scalar_type<f8_ocp_t>
// {
// using type = f8_ocp_t;
// static constexpr index_t vector_size = 1;
// };
// template <>
// struct scalar_type<bf8_ocp_t>
// {
// using type = bf8_ocp_t;
// static constexpr index_t vector_size = 1;
// };
template
<
>
struct
scalar_type
<
bool
>
{
...
...
@@ -1037,20 +1089,71 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
using
f8x2_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
2
>::
type
;
using
f8x4_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
4
>::
type
;
using
f8x8_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
8
>::
type
;
using
f8x16_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
16
>::
type
;
using
f8x32_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
32
>::
type
;
using
f8x64_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
64
>::
type
;
// bf8
using
bf8x2_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
2
>::
type
;
using
bf8x4_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
4
>::
type
;
using
bf8x8_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
8
>::
type
;
using
bf8x16_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
16
>::
type
;
using
bf8x32_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
32
>::
type
;
using
bf8x64_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
64
>::
type
;
// f8
using
f8x2_ocp_t
=
typename
vector_type
<
f8_ocp_t
::
type
,
2
>::
type
;
using
f8x4_ocp_t
=
typename
vector_type
<
f8_ocp_t
::
type
,
4
>::
type
;
using
f8x8_ocp_t
=
typename
vector_type
<
f8_ocp_t
::
type
,
8
>::
type
;
using
f8x16_ocp_t
=
typename
vector_type
<
f8_ocp_t
::
type
,
16
>::
type
;
using
f8x32_ocp_t
=
typename
vector_type
<
f8_ocp_t
::
type
,
32
>::
type
;
using
f8x64_ocp_t
=
typename
vector_type
<
f8_ocp_t
::
type
,
64
>::
type
;
// bf8
using
bf8x2_ocp_t
=
typename
vector_type
<
bf8_ocp_t
::
type
,
2
>::
type
;
using
bf8x4_ocp_t
=
typename
vector_type
<
bf8_ocp_t
::
type
,
4
>::
type
;
using
bf8x8_ocp_t
=
typename
vector_type
<
bf8_ocp_t
::
type
,
8
>::
type
;
using
bf8x16_ocp_t
=
typename
vector_type
<
bf8_ocp_t
::
type
,
16
>::
type
;
using
bf8x32_ocp_t
=
typename
vector_type
<
bf8_ocp_t
::
type
,
32
>::
type
;
using
bf8x64_ocp_t
=
typename
vector_type
<
bf8_ocp_t
::
type
,
64
>::
type
;
#if CK_FP8_TYPE_OCP
// f8
using
f8x2_t
=
f8x2_ocp_t
;
using
f8x4_t
=
f8x4_ocp_t
;
using
f8x8_t
=
f8x8_ocp_t
;
using
f8x16_t
=
f8x16_ocp_t
;
using
f8x32_t
=
f8x32_ocp_t
;
using
f8x64_t
=
f8x64_ocp_t
;
// bf8
using
bf8x2_t
=
bf8x2_ocp_t
;
using
bf8x4_t
=
bf8x4_ocp_t
;
using
bf8x8_t
=
bf8x8_ocp_t
;
using
bf8x16_t
=
bf8x16_ocp_t
;
using
bf8x32_t
=
bf8x32_ocp_t
;
using
bf8x64_t
=
bf8x64_ocp_t
;
#elif CK_FP8_TYPE_FNUZ
// f8
using
f8x2_t
=
f8x2_fnuz_t
;
using
f8x4_t
=
f8x4_fnuz_t
;
using
f8x8_t
=
f8x8_fnuz_t
;
using
f8x16_t
=
f8x16_fnuz_t
;
using
f8x32_t
=
f8x32_fnuz_t
;
using
f8x64_t
=
f8x64_fnuz_t
;
// bf8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
using
bf8x2_t
=
bf8x2_fnuz_t
;
using
bf8x4_t
=
bf8x4_fnuz_t
;
using
bf8x8_t
=
bf8x8_fnuz_t
;
using
bf8x16_t
=
bf8x16_fnuz_t
;
using
bf8x32_t
=
bf8x32_fnuz_t
;
using
bf8x64_t
=
bf8x64_fnuz_t
;
#endif
// u8
// i8
using
uint8x2_t
=
typename
vector_type
<
uint8_t
,
2
>::
type
;
...
...
@@ -1107,7 +1210,7 @@ struct NumericLimits<int4_t>
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
NumericLimits
<
f8_t
>
struct
NumericLimits
<
f8_
fnuz_
t
>
{
// negative zero nan mode with exp bias = 8
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
...
...
@@ -1120,17 +1223,17 @@ struct NumericLimits<f8_t>
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
f8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
Min
()
{
return
f8_
fnuz_
t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
f8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
Max
()
{
return
f8_
fnuz_
t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
f8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
Lowest
()
{
return
f8_
fnuz_
t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
QuietNaN
()
{
return
f8_
fnuz_
t
(
binary_qnan
);
}
};
template
<
>
struct
NumericLimits
<
bf8_t
>
struct
NumericLimits
<
bf8_
fnuz_
t
>
{
// negative zero nan mode with exp bias = 16
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100
...
...
@@ -1143,13 +1246,59 @@ struct NumericLimits<bf8_t>
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__
__device__
static
constexpr
bf8_t
Min
()
{
return
bf8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
bf8_
fnuz_
t
Min
()
{
return
bf8_
fnuz_
t
(
binary_min
);
}
__host__
__device__
static
constexpr
bf8_t
Max
()
{
return
bf8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
bf8_
fnuz_
t
Max
()
{
return
bf8_
fnuz_
t
(
binary_max
);
}
__host__
__device__
static
constexpr
bf8_t
Lowest
()
{
return
bf8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
bf8_
fnuz_
t
Lowest
()
{
return
bf8_
fnuz_
t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
bf8_fnuz_t
QuietNaN
()
{
return
bf8_fnuz_t
(
binary_qnan
);
}
};
template
<
>
struct
NumericLimits
<
f8_ocp_t
>
{
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000 = 2^-6
static
constexpr
uint8_t
binary_max
=
0x7E
;
// 0b01111110 = 448
static
constexpr
uint8_t
binary_lowest
=
0xFE
;
// 0b11111110 = -448
static
constexpr
uint8_t
binary_qnan
=
0x7F
;
// 0b01111111
__host__
__device__
static
constexpr
f8_ocp_t
Min
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_ocp_t
Max
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_ocp_t
Lowest
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_ocp_t
QuietNaN
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_qnan
);
}
};
template
<
>
struct
NumericLimits
<
bf8_ocp_t
>
{
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100 = 2^-14
static
constexpr
uint8_t
binary_max
=
0x7B
;
// 0b01111011 = 57344
static
constexpr
uint8_t
binary_lowest
=
0xFB
;
// 0b11111011 = -57344
static
constexpr
uint8_t
binary_qnan
=
0x7D
;
// 0b01111101
__host__
__device__
static
constexpr
bf8_ocp_t
Min
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
bf8_ocp_t
Max
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
bf8_ocp_t
Lowest
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
bf8_ocp_t
QuietNaN
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_qnan
);
}
};
template
<
typename
T
>
...
...
@@ -1192,7 +1341,7 @@ struct NumericUtils<half_t>
};
template
<
>
struct
NumericUtils
<
f8_t
>
struct
NumericUtils
<
f8_
fnuz_
t
>
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
...
...
@@ -1201,11 +1350,27 @@ struct NumericUtils<f8_t>
};
template
<
>
struct
NumericUtils
<
bf8_t
>
struct
NumericUtils
<
bf8_
fnuz_
t
>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
16
;
// negative zero nan mode
// static constexpr int bias = 15; // ieee mode
};
template
<
>
struct
NumericUtils
<
f8_ocp_t
>
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
7
;
};
template
<
>
struct
NumericUtils
<
bf8_ocp_t
>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
15
;
};
}
// namespace ck
include/ck/utility/type_convert.hpp
View file @
79a4b17f
...
...
@@ -163,7 +163,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_sr
<
f8_
fnuz_
t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
...
...
@@ -189,33 +189,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
cast_to_f8
<
float
,
f8_
fnuz_
t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_sr
<
f8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx94__)
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_sr
<
f8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
utils
::
cast_to_f8
<
half_t
,
f8_fnuz_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_sr
<
bf8_
fnuz_
t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
...
...
@@ -240,28 +242,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
return
utils
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
utils
::
cast_to_f8
<
float
,
bf8_fnuz_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_sr
<
bf8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx94__)
// convert to float and use native converion
return
f8_convert_sr
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_sr
<
bf8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
utils
::
cast_to_f8
<
half_t
,
bf8_fnuz_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
...
...
@@ -271,7 +277,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template
<
>
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_rne
<
f8_
fnuz_
t
,
float
>
(
float
x
)
{
#if defined(__gfx94__)
union
...
...
@@ -296,32 +302,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
cast_to_f8
<
float
,
f8_
fnuz_
t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp16 to fp8 with rounding to nearest even
template
<
>
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_rne
<
f8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx94__)
// convert to float and use native converion
return
f8_convert_rne
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_rne
<
f8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
utils
::
cast_to_f8
<
half_t
,
f8_fnuz_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp32 to bf8 with rounding to nearest even
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_rne
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_rne
<
bf8_
fnuz_
t
,
float
>
(
float
x
)
{
#if defined(__gfx94__)
union
...
...
@@ -345,44 +353,48 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
utils
::
cast_to_f8
<
float
,
bf8_fnuz_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp16 to bf8 with rounding to nearest even
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_rne
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_rne
<
bf8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx94__)
// convert to float and use native converion
return
f8_convert_rne
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
return
f8_convert_rne
<
bf8_
fnuz_
t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
utils
::
cast_to_f8
<
half_t
,
bf8_fnuz_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp32 to fp8
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_
fnuz_
t
type_convert
<
f8_
fnuz_
t
,
float
>
(
float
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_t
>
(
x
);
return
f8_convert_sr
<
f8_
fnuz_
t
>
(
x
);
#else
return
f8_convert_rne
<
f8_t
>
(
x
);
return
f8_convert_rne
<
f8_
fnuz_
t
>
(
x
);
#endif
}
// convert fp8 to fp32
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
f8_
t
>
(
f8
_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
f8_
fnuz_t
>
(
f8_fnuz
_t
x
)
{
#if defined(__gfx94__)
float
fval
;
...
...
@@ -392,7 +404,7 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return
fval
;
#else
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
x
);
#endif
}
...
...
@@ -404,14 +416,14 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
return
__builtin_amdgcn_cvt_pk_f32_fp8
(
i16val
,
0
);
#else
constexpr
bool
negative_zero_nan
=
true
;
const
auto
f8x2_v
=
vector_type
<
f8_t
,
2
>
(
x
);
const
auto
f8x2_v
=
vector_type
<
f8_
fnuz_
t
,
2
>
(
x
);
vector_type
<
float
,
2
>
f32x2_v
;
f32x2_v
.
template
AsType
<
float
>()(
Number
<
0
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_t
>()[
Number
<
0
>
{}]);
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_
fnuz_
t
>()[
Number
<
0
>
{}]);
f32x2_v
.
template
AsType
<
float
>()(
Number
<
1
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_t
>()[
Number
<
1
>
{}]);
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_
fnuz_
t
>()[
Number
<
1
>
{}]);
return
f32x2_v
.
template
AsType
<
float2_t
>()[
Number
<
0
>
{}];
#endif
}
...
...
@@ -428,42 +440,42 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// convert fp16 to fp8
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_
fnuz_
t
type_convert
<
f8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_t
>
(
x
);
return
f8_convert_sr
<
f8_
fnuz_
t
>
(
x
);
#else
return
f8_convert_rne
<
f8_t
>
(
x
);
return
f8_convert_rne
<
f8_
fnuz_
t
>
(
x
);
#endif
}
// convert fp8 to fp16
template
<
>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_
t
>
(
f8
_t
x
)
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_
fnuz_t
>
(
f8_fnuz
_t
x
)
{
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
half_t
,
negative_zero_nan
>
(
x
);
#endif
}
// convert fp32 to bf8
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_
fnuz_
t
type_convert
<
bf8_
fnuz_
t
,
float
>
(
float
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_t
>
(
x
);
return
f8_convert_sr
<
bf8_
fnuz_
t
>
(
x
);
#else
return
f8_convert_rne
<
bf8_t
>
(
x
);
return
f8_convert_rne
<
bf8_
fnuz_
t
>
(
x
);
#endif
}
// convert bf8 to fp32
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
bf8_t
>
(
bf8_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
bf8_
fnuz_
t
>
(
bf8_
fnuz_
t
x
)
{
#if defined(__gfx94__)
float
fval
;
...
...
@@ -473,31 +485,31 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
return
fval
;
#else
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
x
);
#endif
}
// convert fp16 to bf8
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_
fnuz_
t
type_convert
<
bf8_
fnuz_
t
,
half_t
>
(
half_t
x
)
{
#if CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_t
>
(
x
);
return
f8_convert_sr
<
bf8_
fnuz_
t
>
(
x
);
#else
return
f8_convert_rne
<
bf8_t
>
(
x
);
return
f8_convert_rne
<
bf8_
fnuz_
t
>
(
x
);
#endif
}
// convert bf8 to fp16
template
<
>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
bf8_t
>
(
bf8_t
x
)
inline
__host__
__device__
half_t
type_convert
<
half_t
,
bf8_
fnuz_
t
>
(
bf8_
fnuz_
t
x
)
{
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_
fnuz_
t
,
half_t
,
negative_zero_nan
>
(
x
);
#endif
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment