Commit 0a8edad5 authored by Umang Yadav's avatar Umang Yadav
Browse files

works except constexpr

parent d734871c
......@@ -27,7 +27,7 @@
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
};
template <>
struct common_type<migraphx::fp8e4m3fnuz, migraphx::half>
struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx::half>
{
using type = float;
};
template <>
struct common_type<migraphx::half, migraphx::fp8e4m3fnuz>
struct common_type<migraphx::half, migraphx_fp8::fp8e4m3fnuz>
{
using type = float;
};
......
This diff is collapsed.
......@@ -25,8 +25,22 @@
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
namespace migraphx_hip_f8_impl {
namespace detail {
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
} // namespace detail
// #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); }
......@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl {
// #endif
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIGRAPHX_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
{
constexpr bool is_half = migraphx::is_same<T, migraphx::half>{};
constexpr bool is_float = migraphx::is_same<T, float>{};
static_assert(wm + we == 7, "wm+we==7");
static_assert(is_half || is_float, "Only half and float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
uint32_t x;
......@@ -215,29 +227,12 @@ this case, the fp16 mantissa should be shift left by 1 */
}
template <int wm, int we, typename T, bool negative_zero_nan>
MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
{
constexpr bool is_half = migraphx::is_same<T, migraphx::half>{};
constexpr bool is_float = migraphx::is_same<T, float>{};
static_assert(is_half || is_float, "only half and float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
constexpr int weo = 8;
constexpr int wmo = 23;
T fInf, fNegInf, fNaN, fNeg0;
if(is_half)
{
const uint16_t ihInf = 0x7C00;
const uint16_t ihNegInf = 0xFC00;
const uint16_t ihNaN = 0x7C01;
const uint16_t ihNeg0 = 0x8000;
fInf = reinterpret_cast<const migraphx::half&>(ihInf);
fNegInf = reinterpret_cast<const migraphx::half&>(ihNegInf);
fNaN = reinterpret_cast<const migraphx::half&>(ihNaN);
fNeg0 = reinterpret_cast<const migraphx::half&>(ihNeg0);
}
else if(is_float)
{
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
......@@ -246,7 +241,6 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
}
if(x == 0)
return 0;
......@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
if(exponent == ((1 << we) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
}
typename migraphx::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
if(we == 5 && is_half && !negative_zero_nan)
{
retval = x << 8;
return reinterpret_cast<const T&>(retval);
}
typename detail::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
......
......@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
......@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type, migraphx_fp8::fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
......
......@@ -25,10 +25,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx_fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx_fp8::fp8e4m3fnuz)
template <class T>
using accumulator_type =
......
......@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/migraphx_float8.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
......@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
};
template <>
struct npy_format_descriptor<migraphx::fp8e4m3fnuz>
struct npy_format_descriptor<migraphx_fp8::fp8e4m3fnuz>
{
static std::string format()
{
......
......@@ -60,7 +60,7 @@ endif()
include(Embed)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/ EXTRA_HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/fp8e4m3fnuz.hpp EXTRA_HEADERS_RELATIVE ${CMAKE_SOURCE_DIR}/src/include)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/ EXTRA_HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_float8.hpp ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_hip_f8_impl.hpp EXTRA_HEADERS_RELATIVE ${CMAKE_SOURCE_DIR}/src/include ${CMAKE_SOURCE_DIR}/src/include)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
......
......@@ -35,7 +35,7 @@ namespace migraphx {
namespace math {
constexpr float as_float(migraphx::half x) { return x; }
constexpr float as_float(migraphx::fp8e4m3fnuz x) { return x; }
constexpr float as_float(migraphx_fp8::fp8e4m3fnuz x) { return x; }
template <class T>
constexpr T as_float(T x)
......@@ -78,15 +78,15 @@ constexpr T as_float(T x)
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs) \
MIGRAPHX_RETURNS(migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
auto __device__ name(migraphx_fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \
migraphx_fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y) \
-> migraphx::fp8e4m3fnuz \
inline auto __device__ name(migraphx_fp8::fp8e4m3fnuz x, migraphx_fp8::fp8e4m3fnuz y) \
-> migraphx_fp8::fp8e4m3fnuz \
{ \
return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
return migraphx_fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
}
// Template with two overloads for math functions, one for half2 type and one for more generic
......
......@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
......@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n)
template <class T,
MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or
is_same<T, migraphx::half>{} or is_same<T, migraphx::fp8e4m3fnuz>{})>
is_same<T, migraphx::half>{} or
is_same<T, migraphx_fp8::fp8e4m3fnuz>{})>
constexpr T numeric_max()
{
if constexpr(is_integral<T>{})
......@@ -247,8 +248,8 @@ constexpr T numeric_max()
return __FLT_MAX__;
else if constexpr(is_same<T, migraphx::half>{})
return __FLT16_MAX__;
else if constexpr(is_same<T, migraphx::fp8e4m3fnuz>{})
return T{0x7F, migraphx::fp8e4m3fnuz::from_bits()};
else if constexpr(is_same<T, migraphx_fp8::fp8e4m3fnuz>{})
return migraphx_fp8::F8_Max<T>();
else
return 0;
}
......@@ -263,8 +264,8 @@ constexpr T numeric_lowest()
else
return -numeric_max<T>() - 1;
}
else if constexpr(is_same<T, migraphx::fp8e4m3fnuz>{})
return T{0xFF, migraphx::fp8e4m3fnuz::from_bits()};
else if constexpr(is_same<T, migraphx_fp8::fp8e4m3fnuz>{})
return migraphx_fp8::F8_Lowest<T>();
else
{
return -numeric_max<T>();
......
......@@ -23,7 +23,7 @@
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/hip.hpp>
namespace migraphx {
......
......@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#include "migraphx/kernels/type_traits.hpp"
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
......@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x)
{
constexpr auto shape = get_shape_c<T>{};
if constexpr(is_same<typename T::type, migraphx::fp8e4m3fnuz>{})
if constexpr(is_same<typename T::type, migraphx_fp8::fp8e4m3fnuz>{})
return x;
else if constexpr(shape.lens[Axis] == 1)
return x;
......
......@@ -351,7 +351,7 @@ TEST_CASE(compile_math)
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type or t == migraphx::shape::float8_type)
if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::");
data_types.push_back(name);
if(t != migraphx::shape::float8_type)
......@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max)
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type or t == migraphx::shape::float8_type)
if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::");
migraphx::shape::visit(t, [&](auto as) {
......
......@@ -37,7 +37,7 @@
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type, migraphx_fp8::fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment