Commit 79a4b17f authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Initial introduction of OFP8 data types.

parent 171ed358
......@@ -190,6 +190,14 @@ if (GPU_TARGETS)
add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
endif()
# Select the FP8 flavor(s) from the requested GPU targets.
# Targets matching "gfx12" or "gfx950" add the CK_USE_OCP_FP8 compile definition.
if (GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_USE_OCP_FP8)
set(CK_USE_OCP_FP8 "ON")
endif()
# Targets matching "gfx90a" or "gfx94" add the CK_USE_FNUZ_FP8 compile definition.
# The two checks are independent, so a GPU_TARGETS list that matches both
# patterns defines both macros.
if (GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
else()
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
set(CK_USE_XDL "ON")
......
{
"version": 3,
"configurePresets": [
{
"name": "linux-debug",
"displayName": "Linux Debug",
"hidden": true,
"generator": "Unix Makefiles",
"binaryDir": "${sourceDir}/build/${presetName}",
"installDir": "${sourceDir}/build/install/${presetName}",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
"GPU_TARGETS": "gfx90a",
"BUILD_DEV": "ON",
"CMAKE_CXX_COMPILER": "/opt/rocm/llvm/bin/clang++",
"CMAKE_PREFIX_PATH": "/opt/rocm"
},
"condition": {
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Linux"
}
},
{
"name": "MI355-debug",
"displayName": "MI355 Debug",
"inherits": "linux-debug",
"description": "Development Environment for MI355.",
"environment": {
"NONE": ""
},
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0"
}
},
{
"name": "MI355-release",
"displayName": "MI355 Release",
"inherits": "MI355-debug",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
}
],
"buildPresets": [
{
"name": "Debug",
"hidden": true,
"configuration": "Debug"
},
{
"name": "Release",
"hidden": true,
"configuration": "Release"
},
{
"name": "MI355-debug",
"displayName": "MI355",
"configurePreset": "MI355-debug",
"description": "Build Environment for MI355 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
},
{
"name": "MI355-release",
"displayName": "MI355",
"configurePreset": "MI355-release",
"description": "Build Environment for MI355 Release.",
"inherits": [
"Release"
],
"jobs": 128
}
]
}
......@@ -5,13 +5,51 @@
#include "ck/utility/statically_indexed_array.hpp"
// Normalize the FP8 flavor switches to exactly 0 or 1 so that later code can
// test them with `#if` instead of `#ifdef`.
//
// A switch counts as enabled when it is defined at all, whatever its body. The
// incoming definition is dropped with #undef before the canonical `1` is
// installed: redefining a macro with a different replacement list (for example
// one that arrived as `-DCK_USE_OCP_FP8=ON` or as an empty `#define`) is an
// incompatible redefinition and is diagnosed by -Wmacro-redefined.
#ifdef CK_USE_FNUZ_FP8
#undef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 1
#else
#define CK_USE_FNUZ_FP8 0
#endif
#ifdef CK_USE_OCP_FP8
#undef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 1
#else
#define CK_USE_OCP_FP8 0
#endif
namespace ck {
using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
// Raw one-byte storage unit used by both OCP FP8 wrapper structs.
using __hip_fp8_storage_t = unsigned char;

// OCP fp8 value: a distinct struct type holding a single raw storage byte.
struct f8_ocp_t
{
    using type = __hip_fp8_storage_t; // underlying storage type
    type data;                        // raw bits; not initialized by default
};

// OCP bf8 value: same storage as f8_ocp_t, but a separate type so the two
// formats can be told apart in overloads and template specializations.
struct bf8_ocp_t
{
    using type = __hip_fp8_storage_t; // underlying storage type
    type data;                        // raw bits; not initialized by default
};
// Pick the types behind the default f8_t / bf8_t names.
// OCP is chosen whenever CK_USE_OCP_FP8 is non-zero; every other case falls
// back to FNUZ, regardless of CK_USE_FNUZ_FP8. CK_FP8_TYPE_FNUZ and
// CK_FP8_TYPE_OCP record the choice: exactly one of them is 1.
#if CK_USE_OCP_FP8
using f8_t = f8_ocp_t;
using bf8_t = bf8_ocp_t;
#define CK_FP8_TYPE_FNUZ 0
#define CK_FP8_TYPE_OCP 1
#else
using f8_t = f8_fnuz_t;
using bf8_t = bf8_fnuz_t;
#define CK_FP8_TYPE_FNUZ 1
#define CK_FP8_TYPE_OCP 0
#endif
// vector_type
template <typename T, index_t N>
......@@ -150,19 +188,33 @@ struct scalar_type<int4_t>
#endif
template <>
struct scalar_type<f8_t>
struct scalar_type<f8_fnuz_t>
{
using type = f8_t;
using type = f8_fnuz_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_t>
struct scalar_type<bf8_fnuz_t>
{
using type = bf8_t;
using type = bf8_fnuz_t;
static constexpr index_t vector_size = 1;
};
// template <>
// struct scalar_type<f8_ocp_t>
// {
// using type = f8_ocp_t;
// static constexpr index_t vector_size = 1;
// };
// template <>
// struct scalar_type<bf8_ocp_t>
// {
// using type = bf8_ocp_t;
// static constexpr index_t vector_size = 1;
// };
template <>
struct scalar_type<bool>
{
......@@ -1037,20 +1089,71 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
using f8x2_fnuz_t = typename vector_type<f8_fnuz_t, 2>::type;
using f8x4_fnuz_t = typename vector_type<f8_fnuz_t, 4>::type;
using f8x8_fnuz_t = typename vector_type<f8_fnuz_t, 8>::type;
using f8x16_fnuz_t = typename vector_type<f8_fnuz_t, 16>::type;
using f8x32_fnuz_t = typename vector_type<f8_fnuz_t, 32>::type;
using f8x64_fnuz_t = typename vector_type<f8_fnuz_t, 64>::type;
// bf8
using bf8x2_fnuz_t = typename vector_type<bf8_fnuz_t, 2>::type;
using bf8x4_fnuz_t = typename vector_type<bf8_fnuz_t, 4>::type;
using bf8x8_fnuz_t = typename vector_type<bf8_fnuz_t, 8>::type;
using bf8x16_fnuz_t = typename vector_type<bf8_fnuz_t, 16>::type;
using bf8x32_fnuz_t = typename vector_type<bf8_fnuz_t, 32>::type;
using bf8x64_fnuz_t = typename vector_type<bf8_fnuz_t, 64>::type;
// f8
using f8x2_ocp_t = typename vector_type<f8_ocp_t::type, 2>::type;
using f8x4_ocp_t = typename vector_type<f8_ocp_t::type, 4>::type;
using f8x8_ocp_t = typename vector_type<f8_ocp_t::type, 8>::type;
using f8x16_ocp_t = typename vector_type<f8_ocp_t::type, 16>::type;
using f8x32_ocp_t = typename vector_type<f8_ocp_t::type, 32>::type;
using f8x64_ocp_t = typename vector_type<f8_ocp_t::type, 64>::type;
// bf8
using bf8x2_ocp_t = typename vector_type<bf8_ocp_t::type, 2>::type;
using bf8x4_ocp_t = typename vector_type<bf8_ocp_t::type, 4>::type;
using bf8x8_ocp_t = typename vector_type<bf8_ocp_t::type, 8>::type;
using bf8x16_ocp_t = typename vector_type<bf8_ocp_t::type, 16>::type;
using bf8x32_ocp_t = typename vector_type<bf8_ocp_t::type, 32>::type;
using bf8x64_ocp_t = typename vector_type<bf8_ocp_t::type, 64>::type;
#if CK_FP8_TYPE_OCP
// f8
using f8x2_t = f8x2_ocp_t;
using f8x4_t = f8x4_ocp_t;
using f8x8_t = f8x8_ocp_t;
using f8x16_t = f8x16_ocp_t;
using f8x32_t = f8x32_ocp_t;
using f8x64_t = f8x64_ocp_t;
// bf8
using bf8x2_t = bf8x2_ocp_t;
using bf8x4_t = bf8x4_ocp_t;
using bf8x8_t = bf8x8_ocp_t;
using bf8x16_t = bf8x16_ocp_t;
using bf8x32_t = bf8x32_ocp_t;
using bf8x64_t = bf8x64_ocp_t;
#elif CK_FP8_TYPE_FNUZ
// f8
using f8x2_t = f8x2_fnuz_t;
using f8x4_t = f8x4_fnuz_t;
using f8x8_t = f8x8_fnuz_t;
using f8x16_t = f8x16_fnuz_t;
using f8x32_t = f8x32_fnuz_t;
using f8x64_t = f8x64_fnuz_t;
// bf8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
using bf8x2_t = bf8x2_fnuz_t;
using bf8x4_t = bf8x4_fnuz_t;
using bf8x8_t = bf8x8_fnuz_t;
using bf8x16_t = bf8x16_fnuz_t;
using bf8x32_t = bf8x32_fnuz_t;
using bf8x64_t = bf8x64_fnuz_t;
#endif
// u8
// i8
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
......@@ -1107,7 +1210,7 @@ struct NumericLimits<int4_t>
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<f8_t>
struct NumericLimits<f8_fnuz_t>
{
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000
......@@ -1120,17 +1223,17 @@ struct NumericLimits<f8_t>
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
__host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
__host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
__host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
__host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
};
template <>
struct NumericLimits<bf8_t>
struct NumericLimits<bf8_fnuz_t>
{
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
......@@ -1143,13 +1246,59 @@ struct NumericLimits<bf8_t>
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
__host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
__host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
__host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
__host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
__host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
__host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
};
// Limits of the OCP fp8 type, stored as raw bit patterns and exposed as
// f8_ocp_t values through bit_cast.
template <>
struct NumericLimits<f8_ocp_t>
{
    // Raw encodings; the decoded value is given next to each pattern.
    static constexpr uint8_t binary_min    = 0x08; // 0b00001000 = 2^-6
    static constexpr uint8_t binary_max    = 0x7E; // 0b01111110 = 448
    static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
    static constexpr uint8_t binary_qnan   = 0x7F; // 0b01111111

    // Value encoded by binary_min (2^-6).
    __host__ __device__ static constexpr f8_ocp_t Min()
    {
        return bit_cast<f8_ocp_t>(binary_min);
    }

    // Value encoded by binary_max (448).
    __host__ __device__ static constexpr f8_ocp_t Max()
    {
        return bit_cast<f8_ocp_t>(binary_max);
    }

    // Value encoded by binary_lowest (-448).
    __host__ __device__ static constexpr f8_ocp_t Lowest()
    {
        return bit_cast<f8_ocp_t>(binary_lowest);
    }

    // NaN encoding binary_qnan.
    __host__ __device__ static constexpr f8_ocp_t QuietNaN()
    {
        return bit_cast<f8_ocp_t>(binary_qnan);
    }
};
// Limits of the OCP bf8 type, stored as raw bit patterns and exposed as
// bf8_ocp_t values through bit_cast.
template <>
struct NumericLimits<bf8_ocp_t>
{
    // Raw encodings; the decoded value is given next to each pattern.
    static constexpr uint8_t binary_min    = 0x04; // 0b00000100 = 2^-14
    static constexpr uint8_t binary_max    = 0x7B; // 0b01111011 = 57344
    static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
    // NOTE(review): the top mantissa bit of this pattern is 0. Under the usual
    // IEEE-754 convention that bit distinguishes quiet from signaling NaNs, so
    // confirm 0x7D is the intended encoding for QuietNaN().
    static constexpr uint8_t binary_qnan   = 0x7D; // 0b01111101

    // Value encoded by binary_min (2^-14).
    __host__ __device__ static constexpr bf8_ocp_t Min()
    {
        return bit_cast<bf8_ocp_t>(binary_min);
    }

    // Value encoded by binary_max (57344).
    __host__ __device__ static constexpr bf8_ocp_t Max()
    {
        return bit_cast<bf8_ocp_t>(binary_max);
    }

    // Value encoded by binary_lowest (-57344).
    __host__ __device__ static constexpr bf8_ocp_t Lowest()
    {
        return bit_cast<bf8_ocp_t>(binary_lowest);
    }

    // NaN encoding binary_qnan.
    __host__ __device__ static constexpr bf8_ocp_t QuietNaN()
    {
        return bit_cast<bf8_ocp_t>(binary_qnan);
    }
};
template <typename T>
......@@ -1192,7 +1341,7 @@ struct NumericUtils<half_t>
};
template <>
struct NumericUtils<f8_t>
struct NumericUtils<f8_fnuz_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
......@@ -1201,11 +1350,27 @@ struct NumericUtils<f8_t>
};
template <>
struct NumericUtils<bf8_t>
struct NumericUtils<bf8_fnuz_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode
};
// Bit-layout parameters of the OCP fp8 type.
template <>
struct NumericUtils<f8_ocp_t>
{
    static constexpr int exp  = 4; // exponent width in bits
    static constexpr int mant = 3; // mantissa width in bits
    static constexpr int bias = 7; // exponent bias, 2^(exp-1) - 1
};
// Bit-layout parameters of the OCP bf8 type.
template <>
struct NumericUtils<bf8_ocp_t>
{
    static constexpr int exp  = 5;  // exponent width in bits
    static constexpr int mant = 2;  // mantissa width in bits
    static constexpr int bias = 15; // exponent bias, 2^(exp-1) - 1
};
} // namespace ck
......@@ -163,7 +163,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
......@@ -189,33 +189,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// Convert fp16 to fp8 with stochastic rounding.
// On __gfx94__ the input is widened with type_convert<float> and forwarded to
// the float specialization of f8_convert_sr. Elsewhere, prand_generator (fixed
// seed 1254739, fed the address and the value of x) supplies a 32-bit value
// that is passed to utils::cast_to_f8 with negative_zero_nan and clip both true
// and the stochastic-rounding flag set.
// NOTE(review): x is a by-value parameter, so &x is the address of this
// function's own copy rather than of the caller's object.
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native conversion
return f8_convert_sr<f8_t>(type_convert<float>(x));
return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
f8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
......@@ -240,28 +242,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<float,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// Convert fp16 to bf8 with stochastic rounding.
// On __gfx94__ the input is widened with type_convert<float> and forwarded to
// the float specialization of f8_convert_sr. Elsewhere, prand_generator (fixed
// seed 1254739, fed the address and the value of x) supplies a 32-bit value
// that is passed to utils::cast_to_f8 with negative_zero_nan and clip both true
// and the stochastic-rounding flag set.
// NOTE(review): x is a by-value parameter, so &x is the address of this
// function's own copy rather than of the caller's object.
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native conversion
return f8_convert_sr<bf8_t>(type_convert<float>(x));
return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
......@@ -271,7 +277,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
{
#if defined(__gfx94__)
union
......@@ -296,32 +302,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// Convert fp16 to fp8 with rounding to nearest even.
// On __gfx94__ the input is widened with type_convert<float> and forwarded to
// the float specialization of f8_convert_rne. Elsewhere it calls
// utils::cast_to_f8 with negative_zero_nan and clip both true, the
// stochastic-rounding flag false (rm is standard) and rng fixed at 0.
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native conversion
return f8_convert_rne<f8_t>(type_convert<float>(x));
return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
f8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
{
#if defined(__gfx94__)
union
......@@ -345,44 +353,48 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<float,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// Convert fp16 to bf8 with rounding to nearest even.
// On __gfx94__ the input is widened with type_convert<float> and forwarded to
// the float specialization of f8_convert_rne. Elsewhere it calls
// utils::cast_to_f8 with negative_zero_nan and clip both true, the
// stochastic-rounding flag false (rm is standard) and rng fixed at 0.
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
// convert to float and use native conversion
return f8_convert_rne<bf8_t>(type_convert<float>(x));
return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// Convert fp32 to fp8.
// The rounding mode is fixed at compile time: stochastic rounding
// (f8_convert_sr) when CK_USE_SR_F8_CONVERSION is non-zero, otherwise
// round-to-nearest-even (f8_convert_rne).
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
return f8_convert_sr<f8_fnuz_t>(x);
#else
return f8_convert_rne<f8_t>(x);
return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// convert fp8 to fp32
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
float fval;
......@@ -392,7 +404,7 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
return utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(x);
#endif
}
......@@ -404,14 +416,14 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x);
const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif
}
......@@ -428,42 +440,42 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// Convert fp16 to fp8.
// The rounding mode is fixed at compile time: stochastic rounding
// (f8_convert_sr) when CK_USE_SR_F8_CONVERSION is non-zero, otherwise
// round-to-nearest-even (f8_convert_rne).
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
return f8_convert_sr<f8_fnuz_t>(x);
#else
return f8_convert_rne<f8_t>(x);
return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// Convert fp8 to fp16.
// On __gfx94__ the value goes through float: type_convert<float>(x) followed by
// type_convert<half_t>. Elsewhere utils::cast_from_f8 produces the half_t
// directly, with negative_zero_nan set to true.
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
// Convert fp32 to bf8.
// The rounding mode is fixed at compile time: stochastic rounding
// (f8_convert_sr) when CK_USE_SR_F8_CONVERSION is non-zero, otherwise
// round-to-nearest-even (f8_convert_rne).
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
return f8_convert_sr<bf8_fnuz_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// convert bf8 to fp32
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
float fval;
......@@ -473,31 +485,31 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
#endif
}
// Convert fp16 to bf8.
// The rounding mode is fixed at compile time: stochastic rounding
// (f8_convert_sr) when CK_USE_SR_F8_CONVERSION is non-zero, otherwise
// round-to-nearest-even (f8_convert_rne).
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
return f8_convert_sr<bf8_fnuz_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// Convert bf8 to fp16.
// On __gfx94__ the value goes through float: type_convert<float>(x) followed by
// type_convert<half_t>. Elsewhere utils::cast_from_f8 produces the half_t
// directly, with negative_zero_nan set to true.
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment