Commit 0a8edad5 authored by Umang Yadav's avatar Umang Yadav
Browse files

works except constexpr

parent d734871c
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <half/half.hpp> #include <half/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/fp8e4m3fnuz.hpp> #include <migraphx/migraphx_float8.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT ...@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
}; };
template <> template <>
struct common_type<migraphx::fp8e4m3fnuz, migraphx::half> struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx::half>
{ {
using type = float; using type = float;
}; };
template <> template <>
struct common_type<migraphx::half, migraphx::fp8e4m3fnuz> struct common_type<migraphx::half, migraphx_fp8::fp8e4m3fnuz>
{ {
using type = float; using type = float;
}; };
......
This diff is collapsed.
...@@ -25,8 +25,22 @@ ...@@ -25,8 +25,22 @@
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
namespace migraphx_hip_f8_impl { namespace migraphx_hip_f8_impl {
namespace detail {
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
} // namespace detail
// #ifdef __HIP_PLATFORM_HCC__ // #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); } // __device__ inline int clz(uint32_t x) { return __clz(x); }
...@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl { ...@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl {
// #endif // #endif
template <int wm, int we, typename T, bool negative_zero_nan, bool clip> template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIGRAPHX_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
{ {
constexpr bool is_half = migraphx::is_same<T, migraphx::half>{};
constexpr bool is_float = migraphx::is_same<T, float>{};
static_assert(wm + we == 7, "wm+we==7"); static_assert(wm + we == 7, "wm+we==7");
static_assert(is_half || is_float, "Only half and float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10; const int mfmt = (sizeof(T) == 4) ? 23 : 10;
uint32_t x; uint32_t x;
...@@ -215,38 +227,20 @@ this case, the fp16 mantissa should be shift left by 1 */ ...@@ -215,38 +227,20 @@ this case, the fp16 mantissa should be shift left by 1 */
} }
template <int wm, int we, typename T, bool negative_zero_nan> template <int wm, int we, typename T, bool negative_zero_nan>
MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x) MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
{ {
constexpr bool is_half = migraphx::is_same<T, migraphx::half>{}; constexpr int weo = 8;
constexpr bool is_float = migraphx::is_same<T, float>{}; constexpr int wmo = 23;
static_assert(is_half || is_float, "only half and float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
T fInf, fNegInf, fNaN, fNeg0; T fInf, fNegInf, fNaN, fNeg0;
if(is_half) const uint32_t ifInf = 0x7F800000;
{ const uint32_t ifNegInf = 0xFF800000;
const uint16_t ihInf = 0x7C00; const uint32_t ifNaN = 0x7F800001;
const uint16_t ihNegInf = 0xFC00; const uint32_t ifNeg0 = 0x80000000;
const uint16_t ihNaN = 0x7C01; fInf = reinterpret_cast<const float&>(ifInf);
const uint16_t ihNeg0 = 0x8000; fNegInf = reinterpret_cast<const float&>(ifNegInf);
fInf = reinterpret_cast<const migraphx::half&>(ihInf); fNaN = reinterpret_cast<const float&>(ifNaN);
fNegInf = reinterpret_cast<const migraphx::half&>(ihNegInf); fNeg0 = reinterpret_cast<const float&>(ifNeg0);
fNaN = reinterpret_cast<const migraphx::half&>(ihNaN);
fNeg0 = reinterpret_cast<const migraphx::half&>(ihNeg0);
}
else if(is_float)
{
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = reinterpret_cast<const float&>(ifInf);
fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
}
if(x == 0) if(x == 0)
return 0; return 0;
...@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x) ...@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
if(exponent == ((1 << we) - 1)) if(exponent == ((1 << we) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
} }
typename migraphx::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval; typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
if(we == 5 && is_half && !negative_zero_nan)
{
retval = x << 8;
return reinterpret_cast<const T&>(retval);
}
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
......
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/fp8e4m3fnuz.hpp> #include <migraphx/migraphx_float8.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape ...@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape
m(half_type, half) \ m(half_type, half) \
m(float_type, float) \ m(float_type, float) \
m(double_type, double) \ m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \ m(float8_type, migraphx_fp8::fp8e4m3fnuz) \
m(uint8_type, uint8_t) \ m(uint8_type, uint8_t) \
m(int8_type, int8_t) \ m(int8_type, int8_t) \
m(uint16_type, uint16_t) \ m(uint16_type, uint16_t) \
......
...@@ -25,10 +25,10 @@ ...@@ -25,10 +25,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP #define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <type_traits> #include <type_traits>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) ...@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx_fp8::fp8e4m3fnuz)
template <class T> template <class T>
using accumulator_type = using accumulator_type =
......
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/fp8e4m3fnuz.hpp> #include <migraphx/migraphx_float8.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#endif #endif
...@@ -145,7 +145,7 @@ struct npy_format_descriptor<half> ...@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
}; };
template <> template <>
struct npy_format_descriptor<migraphx::fp8e4m3fnuz> struct npy_format_descriptor<migraphx_fp8::fp8e4m3fnuz>
{ {
static std::string format() static std::string format()
{ {
......
...@@ -60,7 +60,7 @@ endif() ...@@ -60,7 +60,7 @@ endif()
include(Embed) include(Embed)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/ EXTRA_HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/fp8e4m3fnuz.hpp EXTRA_HEADERS_RELATIVE ${CMAKE_SOURCE_DIR}/src/include) add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/ EXTRA_HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_float8.hpp ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_hip_f8_impl.hpp EXTRA_HEADERS_RELATIVE ${CMAKE_SOURCE_DIR}/src/include ${CMAKE_SOURCE_DIR}/src/include)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp) configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp) file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
......
...@@ -35,7 +35,7 @@ namespace migraphx { ...@@ -35,7 +35,7 @@ namespace migraphx {
namespace math { namespace math {
constexpr float as_float(migraphx::half x) { return x; } constexpr float as_float(migraphx::half x) { return x; }
constexpr float as_float(migraphx::fp8e4m3fnuz x) { return x; } constexpr float as_float(migraphx_fp8::fp8e4m3fnuz x) { return x; }
template <class T> template <class T>
constexpr T as_float(T x) constexpr T as_float(T x)
...@@ -76,17 +76,17 @@ constexpr T as_float(T x) ...@@ -76,17 +76,17 @@ constexpr T as_float(T x)
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...)) MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \ #define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \ template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs) \ auto __device__ name(migraphx_fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \
MIGRAPHX_RETURNS(migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...))) migraphx_fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \ #define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y) \ inline auto __device__ name(migraphx_fp8::fp8e4m3fnuz x, migraphx_fp8::fp8e4m3fnuz y) \
-> migraphx::fp8e4m3fnuz \ -> migraphx_fp8::fp8e4m3fnuz \
{ \ { \
return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \ return migraphx_fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
} }
// Template with two overloads for math functions, one for half2 type and one for more generic // Template with two overloads for math functions, one for half2 type and one for more generic
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp> #include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
...@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n) ...@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n)
template <class T, template <class T,
MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or
is_same<T, migraphx::half>{} or is_same<T, migraphx::fp8e4m3fnuz>{})> is_same<T, migraphx::half>{} or
is_same<T, migraphx_fp8::fp8e4m3fnuz>{})>
constexpr T numeric_max() constexpr T numeric_max()
{ {
if constexpr(is_integral<T>{}) if constexpr(is_integral<T>{})
...@@ -247,8 +248,8 @@ constexpr T numeric_max() ...@@ -247,8 +248,8 @@ constexpr T numeric_max()
return __FLT_MAX__; return __FLT_MAX__;
else if constexpr(is_same<T, migraphx::half>{}) else if constexpr(is_same<T, migraphx::half>{})
return __FLT16_MAX__; return __FLT16_MAX__;
else if constexpr(is_same<T, migraphx::fp8e4m3fnuz>{}) else if constexpr(is_same<T, migraphx_fp8::fp8e4m3fnuz>{})
return T{0x7F, migraphx::fp8e4m3fnuz::from_bits()}; return migraphx_fp8::F8_Max<T>();
else else
return 0; return 0;
} }
...@@ -263,8 +264,8 @@ constexpr T numeric_lowest() ...@@ -263,8 +264,8 @@ constexpr T numeric_lowest()
else else
return -numeric_max<T>() - 1; return -numeric_max<T>() - 1;
} }
else if constexpr(is_same<T, migraphx::fp8e4m3fnuz>{}) else if constexpr(is_same<T, migraphx_fp8::fp8e4m3fnuz>{})
return T{0xFF, migraphx::fp8e4m3fnuz::from_bits()}; return migraphx_fp8::F8_Lowest<T>();
else else
{ {
return -numeric_max<T>(); return -numeric_max<T>();
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/fp8e4m3fnuz.hpp> #include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/hip.hpp> #include <migraphx/kernels/hip.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP #ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP #define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#include "migraphx/kernels/type_traits.hpp" #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/tensor_view.hpp> #include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp> #include <migraphx/kernels/vec.hpp>
...@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T> ...@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x) __device__ __host__ auto vectorize_tensor(T x)
{ {
constexpr auto shape = get_shape_c<T>{}; constexpr auto shape = get_shape_c<T>{};
if constexpr(is_same<typename T::type, migraphx::fp8e4m3fnuz>{}) if constexpr(is_same<typename T::type, migraphx_fp8::fp8e4m3fnuz>{})
return x; return x;
else if constexpr(shape.lens[Axis] == 1) else if constexpr(shape.lens[Axis] == 1)
return x; return x;
......
...@@ -351,7 +351,7 @@ TEST_CASE(compile_math) ...@@ -351,7 +351,7 @@ TEST_CASE(compile_math)
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue; continue;
auto name = migraphx::shape::cpp_type(t); auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type or t == migraphx::shape::float8_type) if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::"); name.insert(0, "migraphx::");
data_types.push_back(name); data_types.push_back(name);
if(t != migraphx::shape::float8_type) if(t != migraphx::shape::float8_type)
...@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max) ...@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max)
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue; continue;
auto name = migraphx::shape::cpp_type(t); auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type or t == migraphx::shape::float8_type) if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::"); name.insert(0, "migraphx::");
migraphx::shape::visit(t, [&](auto as) { migraphx::shape::visit(t, [&](auto as) {
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
m(half_type, half) \ m(half_type, half) \
m(float_type, float) \ m(float_type, float) \
m(double_type, double) \ m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \ m(float8_type, migraphx_fp8::fp8e4m3fnuz) \
m(uint8_type, uint8_t) \ m(uint8_type, uint8_t) \
m(int8_type, int8_t) \ m(int8_type, int8_t) \
m(uint16_type, uint16_t) \ m(uint16_type, uint16_t) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment