/**
 * \file
 * \brief bf16_internal.h provides struct for __nv_bfloat16 types
 */

/**
 * \defgroup CUDA_INTRINSIC_BFLOAT16 bfloat16 Precision Intrinsics
 * This section describes nv_bfloat16 precision intrinsic functions.
 * To use these functions, include the header file \p cuda_bf16.h in your program.
 */

/**
 * \defgroup CUDA_INTRINSIC_BFLOAT16_ARITH Bfloat16 Arithmetic Functions
 * \ingroup CUDA_INTRINSIC_BFLOAT16
 * To use these functions, include the header file \p cuda_bf16.h in your program.
 */

/**
 * \defgroup CUDA_INTRINSIC_BFLOAT16_COMP Bfloat16 Comparison Functions
 * \ingroup CUDA_INTRINSIC_BFLOAT16
 * To use these functions, include the header file \p cuda_bf16.h in your program.
 */

/**
 * \defgroup CUDA_INTRINSIC_BFLOAT162_COMP Bfloat162 Comparison Functions
 * \ingroup CUDA_INTRINSIC_BFLOAT16
 * To use these functions, include the header file \p cuda_bf16.h in your program.
 */

/**
 * \defgroup CUDA_INTRINSIC_BFLOAT162_ARITH Bfloat162 Arithmetic Functions
 * \ingroup CUDA_INTRINSIC_BFLOAT16
 * To use these functions, include the header file \p cuda_bf16.h in your program.
 */

/**
 * \defgroup CUDA_INTRINSIC_BFLOAT16_CONV Bfloat16 Conversion Functions
 * \ingroup CUDA_INTRINSIC_BFLOAT16
 * To use these functions, include the header file \p cuda_bf16.h in your program.
 */

/**
 * \defgroup CUDA_INTRINSIC_BFLOAT162_CONV Bfloat162 Conversion Functions
 * \ingroup CUDA_INTRINSIC_BFLOAT16
 * To use these functions, include the header file \p cuda_bf16.h in your program.
 */

/**
 * \defgroup CUDA_INTRINSIC_BFLOAT16_MATH Bfloat16 Math Functions
 * \ingroup CUDA_INTRINSIC_BFLOAT16
 * To use these functions, include the header file \p cuda_bf16.h in your program.
 */

/**
 * \defgroup CUDA_INTRINSIC_BFLOAT162_MATH Bfloat162 Math Functions
 * \ingroup CUDA_INTRINSIC_BFLOAT16
 * To use these functions, include the header file \p cuda_bf16.h in your program.
 */

#ifndef __BF16_INTERNAL_H__
#define __BF16_INTERNAL_H__

#if defined(__cplusplus) && !defined(__CUDACC__)
#include "vector_types.h" // float2 etc
#endif

#if defined(__CUDACC_RTC__)
    #define __HOST_DEVICE__ __device__
#else
    #include <climits>
    #define __HOST_DEVICE__ __host__ __device__
#endif

#if __cplusplus < 201103L || !defined(__CUDACC__)

// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only
// include a minimal definition of __nv_bfloat16

#include <stdint.h>
/*! \brief Struct to represent a 16 bit brain floating point number.
 *
 * Minimal definition used when the enclosing preprocessor condition selects the
 * C / pre-C++11 / host-only path: it carries only the raw 16-bit storage and
 * no constructors or operators.
 */
typedef struct
{
    uint16_t data; // raw 16-bit encoding of the value
} __nv_bfloat16;

#else // __cplusplus < 201103L || !defined(__CUDACC__)

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <ostream>
#include <type_traits>

// Suppress clang's -Wshadow diagnostic for the struct definition; the matching
// "pop" after the struct restores the previous diagnostic state.
// NOTE(review): presumably triggered by a name declared inside the struct — confirm which one.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wshadow"
// 16-bit brain floating point value stored as the upper 16 bits of an IEEE
// binary32 float. Conversion from float rounds to nearest-even by default, or
// truncates when the truncate_t tag is passed. The default constructor is
// defaulted so the type stays trivial (checked by static_asserts after the struct).
struct __nv_bfloat16
{
    uint16_t data; // raw bfloat16 bits: sign (1), exponent (8), mantissa (7)

    // Tag type selecting the truncating conversions below.
    enum truncate_t
    {
        truncate
    };

    // Defaulted: leaves `data` uninitialised under default-initialisation.
    __HOST_DEVICE__ __nv_bfloat16() = default;

    // round upper 16 bits of IEEE float to convert to bfloat16
    explicit __HOST_DEVICE__ __nv_bfloat16(float f)
        : data(float_to_bfloat16(f))
    {
    }

    // Truncating conversion: keeps the upper 16 bits of the float without rounding.
    explicit __HOST_DEVICE__ __nv_bfloat16(float f, truncate_t)
        : data(truncate_float_to_bfloat16(f))
    {
    }

    // zero extend lower 16 bits of bfloat16 to convert to IEEE float
    __HOST_DEVICE__ operator float() const
    {
        // The 16 stored bits become the high half of the float's bit pattern;
        // the low 16 bits are zero. The union reinterprets those bits as float.
        union
        {
            uint32_t int32;
            float    fp32;
        } u = {uint32_t(data) << 16};
        return u.fp32;
    }

    // Assignment from float uses the rounding conversion.
    __HOST_DEVICE__ __nv_bfloat16 &operator=(const float& f)
    {
       data = float_to_bfloat16(f);
       return *this;
    }

    // Named factory equivalent to the rounding constructor.
    static  __HOST_DEVICE__ __nv_bfloat16 round_to_bfloat16(float f)
    {
        __nv_bfloat16 output;
        output.data = float_to_bfloat16(f);
        return output;
    }

    // Named factory equivalent to the truncating constructor.
    static  __HOST_DEVICE__ __nv_bfloat16 round_to_bfloat16(float f, truncate_t)
    {
        __nv_bfloat16 output;
        output.data = truncate_float_to_bfloat16(f);
        return output;
    }

private:
    // Converts a float to bfloat16 bits with round-to-nearest-even, keeping
    // Inf as Inf and NaN as NaN.
    static __HOST_DEVICE__ uint16_t float_to_bfloat16(float f)
    {
        union
        {
            float    fp32;
            uint32_t int32;
        } u = {f};
        if(~u.int32 & 0x7f800000)
        {
            // When the exponent bits are not all 1s, then the value is zero, normal,
            // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
            // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
            // This causes the bfloat16's mantissa to be incremented by 1 if the 16
            // least significant bits of the float mantissa are greater than 0x8000,
            // or if they are equal to 0x8000 and the least significant bit of the
            // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
            // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
            // has the value 0x7f, then incrementing it causes it to become 0x00 and
            // the exponent is incremented by one, which is the next higher FP value
            // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
            // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
            // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
            // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
            // incrementing it causes it to become an exponent of 0xFF and a mantissa
            // of 0x00, which is Inf, the next higher value to the unrounded value.
            u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
        }
        else if(u.int32 & 0xffff)
        {
            // When all of the exponent bits are 1, the value is Inf or NaN.
            // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
            // mantissa bit. Quiet NaN is indicated by the most significant mantissa
            // bit being 1. Signaling NaN is indicated by the most significant
            // mantissa bit being 0 but some other bit(s) being 1. If any of the
            // lower 16 bits of the mantissa are 1, we set the least significant bit
            // of the bfloat16 mantissa, in order to preserve signaling NaN in case
            // the bfloat16's mantissa bits are all 0.
            u.int32 |= 0x10000; // Preserve signaling NaN
        }
        return uint16_t(u.int32 >> 16);
    }

    // Truncate instead of rounding, preserving SNaN
    static __HOST_DEVICE__ uint16_t truncate_float_to_bfloat16(float f)
    {
        union
        {
            float    fp32;
            uint32_t int32;
        } u = {f};
        // Upper 16 bits of the float, with bit 0 forced to 1 when the exponent is
        // all ones and any of the discarded low 16 bits is set (a NaN whose
        // payload would otherwise be lost and turn into Inf).
        return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
    }
};
#pragma clang diagnostic pop

// Plain-struct mirror of the minimal (public) __nv_bfloat16 definition. It is
// only referenced by the static_asserts that follow, which compare its size and
// member offset against the full C++ struct.
typedef struct
{
    uint16_t data; // same single 16-bit field as the public definition
} __nv_bfloat16_public;

// Compile-time guarantees that the C++ struct can be exchanged with the plain
// C definition: it must be standard-layout and trivial.
static_assert(std::is_standard_layout<__nv_bfloat16>{},
              "__nv_bfloat16 is not a standard layout type, and thus is "
              "incompatible with C.");

static_assert(std::is_trivial<__nv_bfloat16>{},
              "__nv_bfloat16 is not a trivial type, and thus is "
              "incompatible with C.");
// The size/offset comparison below is skipped when __CUDACC_RTC__ is defined.
#if !defined(__CUDACC_RTC__)
// Size and field offset must match the public mirror struct exactly.
static_assert(sizeof(__nv_bfloat16) == sizeof(__nv_bfloat16_public)
                  && offsetof(__nv_bfloat16, data) == offsetof(__nv_bfloat16_public, data),
              "internal __nv_bfloat16 does not match public __nv_bfloat16");

// Streams a bfloat16 by widening it to float and printing that float.
inline std::ostream& operator<<(std::ostream& os, const __nv_bfloat16& bf16)
{
  const float widened = static_cast<float>(bf16);
  os << widened;
  return os;
}
#endif

// Unary plus: yields the operand unchanged.
inline __HOST_DEVICE__ __nv_bfloat16 operator+(__nv_bfloat16 a)
{
    const __nv_bfloat16 unchanged = a;
    return unchanged;
}
// Unary minus: toggles the sign bit (bit 15) of the raw encoding.
inline __HOST_DEVICE__ __nv_bfloat16 operator-(__nv_bfloat16 a)
{
    __nv_bfloat16 negated = a;
    negated.data = static_cast<uint16_t>(negated.data ^ 0x8000);
    return negated;
}
// Addition: computed in float, then rounded back to bfloat16.
inline __HOST_DEVICE__ __nv_bfloat16 operator+(__nv_bfloat16 a, __nv_bfloat16 b)
{
    const float sum = static_cast<float>(a) + static_cast<float>(b);
    return __nv_bfloat16(sum);
}
// Subtraction: computed in float, then rounded back to bfloat16.
inline __HOST_DEVICE__ __nv_bfloat16 operator-(__nv_bfloat16 a, __nv_bfloat16 b)
{
    const float difference = static_cast<float>(a) - static_cast<float>(b);
    return __nv_bfloat16(difference);
}
// Multiplication: computed in float, then rounded back to bfloat16.
inline __HOST_DEVICE__ __nv_bfloat16 operator*(__nv_bfloat16 a, __nv_bfloat16 b)
{
    const float product = static_cast<float>(a) * static_cast<float>(b);
    return __nv_bfloat16(product);
}
// Division: computed in float, then rounded back to bfloat16.
inline __HOST_DEVICE__ __nv_bfloat16 operator/(__nv_bfloat16 a, __nv_bfloat16 b)
{
    const float quotient = static_cast<float>(a) / static_cast<float>(b);
    return __nv_bfloat16(quotient);
}
// Less-than, evaluated on the float values of both operands.
inline __HOST_DEVICE__ bool operator<(__nv_bfloat16 a, __nv_bfloat16 b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs < rhs;
}
// Equality, evaluated on the float values of both operands.
inline __HOST_DEVICE__ bool operator==(__nv_bfloat16 a, __nv_bfloat16 b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs == rhs;
}
// Greater-than, expressed as less-than with the operands swapped.
inline __HOST_DEVICE__ bool operator>(__nv_bfloat16 a, __nv_bfloat16 b)
{
    return operator<(b, a);
}
// Less-than-or-equal, evaluated on the float values of both operands.
// The comparison is ordered: it returns false when either operand is NaN,
// matching IEEE `<=` and the __hle intrinsic. (Deriving it as !(a > b) would
// return true for NaN operands, which is the unordered variant.)
inline __HOST_DEVICE__ bool operator<=(__nv_bfloat16 a, __nv_bfloat16 b)
{
    return float(a) <= float(b);
}
// Inequality: the logical negation of operator==.
inline __HOST_DEVICE__ bool operator!=(__nv_bfloat16 a, __nv_bfloat16 b)
{
    const bool same = (a == b);
    return !same;
}
// Greater-than-or-equal, evaluated on the float values of both operands.
// The comparison is ordered: it returns false when either operand is NaN,
// matching IEEE `>=` and the __hge intrinsic. (Deriving it as !(a < b) would
// return true for NaN operands, which is the unordered variant.)
inline __HOST_DEVICE__ bool operator>=(__nv_bfloat16 a, __nv_bfloat16 b)
{
    return float(a) >= float(b);
}
// In-place addition; returns a reference to the updated left operand.
inline __HOST_DEVICE__ __nv_bfloat16& operator+=(__nv_bfloat16& a, __nv_bfloat16 b)
{
    a = a + b;
    return a;
}
// In-place subtraction; returns a reference to the updated left operand.
inline __HOST_DEVICE__ __nv_bfloat16& operator-=(__nv_bfloat16& a, __nv_bfloat16 b)
{
    a = a - b;
    return a;
}
// In-place multiplication; returns a reference to the updated left operand.
inline __HOST_DEVICE__ __nv_bfloat16& operator*=(__nv_bfloat16& a, __nv_bfloat16 b)
{
    a = a * b;
    return a;
}
// In-place division; returns a reference to the updated left operand.
inline __HOST_DEVICE__ __nv_bfloat16& operator/=(__nv_bfloat16& a, __nv_bfloat16 b)
{
    a = a / b;
    return a;
}
// Pre-increment: adds bfloat16(1.0f) in place and returns the updated operand.
inline __HOST_DEVICE__ __nv_bfloat16& operator++(__nv_bfloat16& a)
{
    const __nv_bfloat16 one(1.0f);
    a += one;
    return a;
}
// Pre-decrement: subtracts bfloat16(1.0f) in place and returns the updated operand.
inline __HOST_DEVICE__ __nv_bfloat16& operator--(__nv_bfloat16& a)
{
    const __nv_bfloat16 one(1.0f);
    a -= one;
    return a;
}
// Post-increment: adds bfloat16(1.0f) in place and returns the value held before.
inline __HOST_DEVICE__ __nv_bfloat16 operator++(__nv_bfloat16& a, int)
{
    const __nv_bfloat16 previous(a);
    a += __nv_bfloat16(1.0f);
    return previous;
}
// Post-decrement: subtracts bfloat16(1.0f) in place and returns the value held before.
inline __HOST_DEVICE__ __nv_bfloat16 operator--(__nv_bfloat16& a, int)
{
    const __nv_bfloat16 previous(a);
    a -= __nv_bfloat16(1.0f);
    return previous;
}

// Classification helpers for __nv_bfloat16, implemented purely on the raw bits.
// Bit layout used below: 0x8000 sign, 0x7f80 exponent, 0x007f mantissa.
// NOTE(review): these are declared inside namespace std, which the C++ standard
// does not sanction for user overloads — confirm this is intentional.
namespace std
{
    // True when the exponent is all ones and the mantissa is zero (+/-Inf).
    constexpr __HOST_DEVICE__ bool isinf(__nv_bfloat16 a)
    {
        return (a.data & 0x7f80) == 0x7f80 && (a.data & 0x7f) == 0;
    }
    // True when the exponent is all ones and the mantissa is non-zero (NaN).
    constexpr __HOST_DEVICE__ bool isnan(__nv_bfloat16 a)
    {
        return (a.data & 0x7f80) == 0x7f80 && (a.data & 0x7f) != 0;
    }
    // True when every bit other than the sign bit is zero (+0 or -0).
    constexpr __HOST_DEVICE__ bool iszero(__nv_bfloat16 a)
    {
        return (a.data & 0x7fff) == 0;
    }
}

#endif // __cplusplus < 201103L || !defined(__CUDACC__)

// Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on
// different machines. These naive checks should prevent some undefined behavior on systems which
// have different sizes for basic types.
// The CHAR_BIT check is skipped under __CUDACC_RTC__, where <climits> is not included above.
#if !defined(__CUDACC_RTC__)
static_assert(CHAR_BIT == 8, "byte size should be of 8 bits");
#endif
// The 16-bit storage must be exactly two bytes wide.
static_assert(sizeof(unsigned short) == 2, "size of unsigned short should be 2 bytes");

/*! \brief Struct to represent two 16 bit brain floating point numbers.
 *
 * The conversion helpers in this header treat \p x as the "low" element and
 * \p y as the "high" element.
 */
struct __nv_bfloat162 {
  __nv_bfloat16 x;  // low element
  __nv_bfloat16 y;  // high element
};

namespace{
/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_CONV
 * \brief Converts bfloat16 to float
 *
 * The 16 stored bits become the upper half of the float's bit pattern and the
 * lower 16 bits are zero, so the conversion is exact.
 */
__HOST_DEVICE__ inline float __bfloat162float(__nv_bfloat16 a) {
  // Widen to unsigned BEFORE shifting: a.data would otherwise be promoted to a
  // signed int, and shifting a value with bit 15 set left by 16 overflows it.
  const unsigned int uval = static_cast<unsigned int>(a.data) << 16;
  union {
    unsigned int u32;
    float fp32;
  } u = {uval};
  return u.fp32;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_CONV
 * \brief Converts float to bfloat16
 *
 * Finite values are rounded to nearest, ties to even. Inf stays Inf and NaN
 * stays NaN.
 */
__HOST_DEVICE__ __nv_bfloat16 __float2bfloat16(float f) {
  __nv_bfloat16 ret;
  union {
    float fp32;
    unsigned int u32;
  } u = {f};
  if (~u.u32 & 0x7f800000) {
    // When the exponent bits are not all 1s, then the value is zero, normal,
    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
    // This causes the bfloat16's mantissa to be incremented by 1 if the 16
    // least significant bits of the float mantissa are greater than 0x8000,
    // or if they are equal to 0x8000 and the least significant bit of the
    // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
    // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
    // has the value 0x7f, then incrementing it causes it to become 0x00 and
    // the exponent is incremented by one, which is the next higher FP value
    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
    // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
    // incrementing it causes it to become an exponent of 0xFF and a mantissa
    // of 0x00, which is Inf, the next higher value to the unrounded value.
    u.u32 += 0x7fff + ((u.u32 >> 16) & 1);  // Round to nearest, round to even
  } else if (u.u32 & 0xffff) {
    // When all of the exponent bits are 1, the value is Inf or NaN.
    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
    // mantissa bit. Quiet NaN is indicated by the most significant mantissa
    // bit being 1. Signaling NaN is indicated by the most significant
    // mantissa bit being 0 but some other bit(s) being 1. If any of the
    // lower 16 bits of the mantissa are 1, we set the least significant bit
    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
    // the bfloat16's mantissa bits are all 0.
    u.u32 |= 0x10000;  // Preserve signaling NaN
  }

  ret.data = (u.u32 >> 16);
  return ret;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Convert double to __nv_bfloat16
 *
 * The value is first narrowed to float and then rounded to bfloat16, i.e. two
 * rounding steps are applied.
 */
__HOST_DEVICE__ __nv_bfloat16 __double2bfloat16(const double a) {
  return __float2bfloat16((float)a);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Convert float2 to __nv_bfloat162
 *
 * Each component is converted independently with __float2bfloat16.
 */
__HOST_DEVICE__ __nv_bfloat162 __float22bfloat162_rn(const float2 a) {
  return __nv_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Converts low 16 bits of __nv_bfloat162 to float and returns the result
 */
__HOST_DEVICE__ float __low2float(const __nv_bfloat162 a) { return __bfloat162float(a.x); }

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Converts high 16 bits of __nv_bfloat162 to float and returns the result
 */
__HOST_DEVICE__ float __high2float(const __nv_bfloat162 a) { return __bfloat162float(a.y); }

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Converts both elements of a bfloat162 to float and returns them as a float2
 */
__HOST_DEVICE__ float2 __bfloat1622float2(const __nv_bfloat162 a) {
  return float2{__bfloat162float(a.x), __bfloat162float(a.y)};
}
}

#if defined(__cplusplus) && defined(__CUDACC__)
//#include "crt/device_functions_internal.h" ocml conversion functions  __clang_cuda include
//#include "crt/math_fwd_internal.h" ocml device functions __clang_cuda include

namespace{
/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Replicates one bfloat16 value into both elements of a bfloat162
 */
__device__ __nv_bfloat162 __bfloat162bfloat162(const __nv_bfloat16 a) {
  __nv_bfloat162 pair;
  pair.x = a;
  pair.y = a;
  return pair;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Returns the raw bits of a __nv_bfloat16 converted to a signed short integer
 */
__device__ short int __bfloat16_as_short(const __nv_bfloat16 h) {
  const short int bits = static_cast<short int>(h.data);
  return bits;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Returns the raw bits of a __nv_bfloat16 as an unsigned short integer
 */
__device__ unsigned short int __bfloat16_as_ushort(const __nv_bfloat16 h) {
  const unsigned short int bits = h.data;
  return bits;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Packs two __nv_bfloat16 values into a __nv_bfloat162 (\p a low, \p b high)
 */
__device__ __nv_bfloat162 __halves2bfloat162(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  __nv_bfloat162 pair;
  pair.x = a;
  pair.y = b;
  return pair;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Returns the high element (y) of a __nv_bfloat162
 */
__device__ __nv_bfloat16 __high2bfloat16(const __nv_bfloat162 a) {
  const __nv_bfloat16 high = a.y;
  return high;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Replicates the high element (y) of a __nv_bfloat162 into both elements of the result
 */
__device__ __nv_bfloat162 __high2bfloat162(const __nv_bfloat162 a) {
  __nv_bfloat162 pair;
  pair.x = a.y;
  pair.y = a.y;
  return pair;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Combines the high elements of \p a (to low) and \p b (to high)
 */
__device__ __nv_bfloat162 __highs2bfloat162(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  __nv_bfloat162 pair;
  pair.x = a.y;
  pair.y = b.y;
  return pair;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Returns the low element (x) of a __nv_bfloat162
 */
__device__ __nv_bfloat16 __low2bfloat16(const __nv_bfloat162 a) {
  const __nv_bfloat16 low = a.x;
  return low;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Replicates the low element (x) of a __nv_bfloat162 into both elements of the result
 */
__device__ __nv_bfloat162 __low2bfloat162(const __nv_bfloat162 a) {
  __nv_bfloat162 pair;
  pair.x = a.x;
  pair.y = a.x;
  return pair;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Swaps the low and high elements of a __nv_bfloat162
 */
__device__ __nv_bfloat162 __lowhigh2highlow(const __nv_bfloat162 a) {
  __nv_bfloat162 swapped;
  swapped.x = a.y;
  swapped.y = a.x;
  return swapped;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Combines the low elements of \p a (to low) and \p b (to high)
 */
__device__ __nv_bfloat162 __lows2bfloat162(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  __nv_bfloat162 pair;
  pair.x = a.x;
  pair.y = b.x;
  return pair;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Reinterprets the bits of a short int as a bfloat16
 *
 * The 16 bits of \p a are stored unchanged; no numeric conversion is performed.
 */
__device__ __nv_bfloat16 __short_as_bfloat16(const short int a) {
  // Write the storage field directly. Brace-initialising the C++11
  // __nv_bfloat16 struct from an integer resolves to its explicit float
  // constructor (a narrowing value conversion), not a bit reinterpretation.
  __nv_bfloat16 ret;
  ret.data = (unsigned short)a;
  return ret;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_CONV
 * \brief Reinterprets the bits of an unsigned short int as a bfloat16
 *
 * The 16 bits of \p a are stored unchanged; no numeric conversion is performed.
 */
__device__ __nv_bfloat16 __ushort_as_bfloat16(const unsigned short int a) {
  // Write the storage field directly. Brace-initialising the C++11
  // __nv_bfloat16 struct from an integer resolves to its explicit float
  // constructor (a narrowing value conversion), not a bit reinterpretation.
  __nv_bfloat16 ret;
  ret.data = a;
  return ret;
}


/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_ARITH
 * \brief Adds two bfloat16 values (computed in float, rounded back to bfloat16)
 */
__device__ __nv_bfloat16 __hadd(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return __float2bfloat16(lhs + rhs);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_ARITH
 * \brief Subtracts \p b from \p a (computed in float, rounded back to bfloat16)
 */
__device__ __nv_bfloat16 __hsub(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return __float2bfloat16(lhs - rhs);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_ARITH
 * \brief Divides \p a by \p b (computed in float, rounded back to bfloat16)
 */
__device__ __nv_bfloat16 __hdiv(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float numerator = __bfloat162float(a);
  const float denominator = __bfloat162float(b);
  return __float2bfloat16(numerator / denominator);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_ARITH
 * \brief Fused multiply-add: passes a, b, c as floats to __ocml_fma_f32 and rounds to bfloat16
 */
__device__ __nv_bfloat16 __hfma(const __nv_bfloat16 a, const __nv_bfloat16 b,
                                 const __nv_bfloat16 c) {
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  const float fc = __bfloat162float(c);
  return __float2bfloat16(__ocml_fma_f32(fa, fb, fc));
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_ARITH
 * \brief Multiplies two bfloat16 values (computed in float, rounded back to bfloat16)
 */
__device__ __nv_bfloat16 __hmul(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return __float2bfloat16(lhs * rhs);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_ARITH
 * \brief Negates a bfloat16 value by toggling its sign bit
 */
__device__ __nv_bfloat16 __hneg(const __nv_bfloat16 a) {
  __nv_bfloat16 negated = a;
  negated.data = static_cast<uint16_t>(negated.data ^ 0x8000);  // bit 15 is the sign
  return negated;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_ARITH
 * \brief Returns the absolute value of a bfloat16 by clearing its sign bit
 */
__device__ __nv_bfloat16 __habs(const __nv_bfloat16 a) {
  __nv_bfloat16 magnitude = a;
  magnitude.data = static_cast<uint16_t>(magnitude.data & 0x7FFF);  // drop bit 15 (sign)
  return magnitude;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_ARITH
 * \brief Divides two bfloat162 values element-wise
 */
__device__ __nv_bfloat162 __h2div(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  __nv_bfloat162 quotient;
  quotient.x = __float2bfloat16(__bfloat162float(a.x) / __bfloat162float(b.x));
  quotient.y = __float2bfloat16(__bfloat162float(a.y) / __bfloat162float(b.y));
  return quotient;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_ARITH
 * \brief Returns the element-wise absolute value of a bfloat162
 */
__device__ __nv_bfloat162 __habs2(const __nv_bfloat162 a) {
  __nv_bfloat162 result;
  result.x = __habs(a.x);
  result.y = __habs(a.y);
  return result;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_ARITH
 * \brief Adds two bfloat162 values element-wise
 */
__device__ __nv_bfloat162 __hadd2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  __nv_bfloat162 result;
  result.x = __hadd(a.x, b.x);
  result.y = __hadd(a.y, b.y);
  return result;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_ARITH
 * \brief Performs an element-wise fused multiply-add of three bfloat162 values
 */
__device__ __nv_bfloat162 __hfma2(const __nv_bfloat162 a, const __nv_bfloat162 b,
                                   const __nv_bfloat162 c) {
  __nv_bfloat162 result;
  result.x = __hfma(a.x, b.x, c.x);
  result.y = __hfma(a.y, b.y, c.y);
  return result;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_ARITH
 * \brief Multiplies two bfloat162 values element-wise
 */
__device__ __nv_bfloat162 __hmul2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  __nv_bfloat162 result;
  result.x = __hmul(a.x, b.x);
  result.y = __hmul(a.y, b.y);
  return result;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_ARITH
 * \brief Negates both elements of a bfloat162
 */
__device__ __nv_bfloat162 __hneg2(const __nv_bfloat162 a) {
  __nv_bfloat162 result;
  result.x = __hneg(a.x);
  result.y = __hneg(a.y);
  return result;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_ARITH
 * \brief Subtracts two bfloat162 values element-wise (\p a - \p b)
 */
__device__ __nv_bfloat162 __hsub2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  __nv_bfloat162 result;
  result.x = __hsub(a.x, b.x);
  result.y = __hsub(a.y, b.y);
  return result;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - equal (false if either is NaN)
 */
__device__ bool __heq(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return lhs == rhs;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - unordered equal (true if either is NaN)
 */
__device__ bool __hequ(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  // Neither strictly less nor strictly greater: equal, or unordered.
  return !(lhs < rhs || lhs > rhs);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - greater than (false if either is NaN)
 */
__device__ bool __hgt(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return lhs > rhs;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - unordered greater than (true if either is NaN)
 */
__device__ bool __hgtu(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  const bool less_or_equal = lhs <= rhs;
  return !less_or_equal;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - greater than or equal (false if either is NaN)
 */
__device__ bool __hge(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return lhs >= rhs;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - unordered greater than or equal (true if either is NaN)
 */
__device__ bool __hgeu(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  const bool strictly_less = lhs < rhs;
  return !strictly_less;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - not equal (ordered)
 *
 * Returns true when \p a and \p b compare strictly less or strictly greater.
 * Returns false when either input is NaN; use __hneu for the unordered variant.
 */
__device__ bool __hne(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  // `lhs != rhs` is true for NaN operands, which would make this identical to
  // __hneu. Testing both strict orderings keeps the comparison ordered.
  return lhs < rhs || lhs > rhs;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - unordered not equal (true if either is NaN)
 */
__device__ bool __hneu(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  const bool equal = lhs == rhs;
  return !equal;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Returns the result of __ocml_fmax_f32 on two bfloat16 values, as bfloat16
 */
__device__ __nv_bfloat16 __hmax(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return __float2bfloat16(__ocml_fmax_f32(lhs, rhs));
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Returns the result of __ocml_fmin_f32 on two bfloat16 values, as bfloat16
 */
__device__ __nv_bfloat16 __hmin(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return __float2bfloat16(__ocml_fmin_f32(lhs, rhs));
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - less than (false if either is NaN)
 */
__device__ bool __hlt(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return lhs < rhs;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - unordered less than (true if either is NaN)
 */
__device__ bool __hltu(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  const bool greater_or_equal = lhs >= rhs;
  return !greater_or_equal;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - less than or equal (false if either is NaN)
 */
__device__ bool __hle(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return lhs <= rhs;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Compare two bfloat16 values - unordered less than or equal (true if either is NaN)
 */
__device__ bool __hleu(const __nv_bfloat16 a, const __nv_bfloat16 b) {
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  const bool strictly_greater = lhs > rhs;
  return !strictly_greater;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Checks if number is inf
 *
 * \return -1 if \p a is negative infinity, 1 if \p a is positive infinity,
 *         0 otherwise (including NaN).
 */
__device__ int __hisinf(const __nv_bfloat16 a) {
  // Infinity has all exponent bits set (0x7f80) and a zero mantissa (0x007f).
  if ((a.data & 0x7fff) != 0x7f80) {
    return 0;
  }
  // Bit 15 is the sign: report the direction of the infinity.
  return (a.data & 0x8000) ? -1 : 1;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_COMP
 * \brief Checks if number is NaN (via __ocml_isnan_f32 on the float value)
 */
__device__ bool __hisnan(const __nv_bfloat16 a) {
  const float widened = __bfloat162float(a);
  return __ocml_isnan_f32(widened);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if both element pairs compare equal (__heq)
 */
__device__ bool __hbeq2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__heq(a.x, b.x)) {
    return false;
  }
  return __heq(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if both element pairs compare equal - unordered (__hequ)
 */
__device__ bool __hbequ2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hequ(a.x, b.x)) {
    return false;
  }
  return __hequ(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if a >= b holds for both element pairs (__hge)
 */
__device__ bool __hbge2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hge(a.x, b.x)) {
    return false;
  }
  return __hge(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if a >= b holds for both element pairs - unordered (__hgeu)
 */
__device__ bool __hbgeu2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hgeu(a.x, b.x)) {
    return false;
  }
  return __hgeu(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if a > b holds for both element pairs (__hgt)
 */
__device__ bool __hbgt2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hgt(a.x, b.x)) {
    return false;
  }
  return __hgt(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if a > b holds for both element pairs - unordered (__hgtu)
 */
__device__ bool __hbgtu2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hgtu(a.x, b.x)) {
    return false;
  }
  return __hgtu(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if a <= b holds for both element pairs (__hle)
 */
__device__ bool __hble2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hle(a.x, b.x)) {
    return false;
  }
  return __hle(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if a <= b holds for both element pairs - unordered (__hleu)
 */
__device__ bool __hbleu2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hleu(a.x, b.x)) {
    return false;
  }
  return __hleu(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if a < b holds for both element pairs (__hlt)
 */
__device__ bool __hblt2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hlt(a.x, b.x)) {
    return false;
  }
  return __hlt(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if a < b holds for both element pairs - unordered (__hltu)
 */
__device__ bool __hbltu2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hltu(a.x, b.x)) {
    return false;
  }
  return __hltu(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if a != b holds for both element pairs (__hne)
 */
__device__ bool __hbne2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hne(a.x, b.x)) {
    return false;
  }
  return __hne(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief True only if a != b holds for both element pairs - unordered (__hneu)
 */
__device__ bool __hbneu2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  if (!__hneu(a.x, b.x)) {
    return false;
  }
  return __hneu(a.y, b.y);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief Element-wise a == b; each element is 1.0 if equal, otherwise 0.0
 */
__device__ __nv_bfloat162 __heq2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  const __nv_bfloat16 one = __float2bfloat16(1.0f);
  const __nv_bfloat16 zero = __float2bfloat16(0.0f);
  __nv_bfloat162 mask;
  mask.x = __heq(a.x, b.x) ? one : zero;
  mask.y = __heq(a.y, b.y) ? one : zero;
  return mask;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief Element-wise a >= b; each element is 1.0 if greater than or equal, otherwise 0.0
 */
__device__ __nv_bfloat162 __hge2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  const __nv_bfloat16 one = __float2bfloat16(1.0f);
  const __nv_bfloat16 zero = __float2bfloat16(0.0f);
  __nv_bfloat162 mask;
  mask.x = __hge(a.x, b.x) ? one : zero;
  mask.y = __hge(a.y, b.y) ? one : zero;
  return mask;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief Element-wise a > b; each element is 1.0 if greater than, otherwise 0.0
 */
__device__ __nv_bfloat162 __hgt2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  const __nv_bfloat16 one = __float2bfloat16(1.0f);
  const __nv_bfloat16 zero = __float2bfloat16(0.0f);
  __nv_bfloat162 mask;
  mask.x = __hgt(a.x, b.x) ? one : zero;
  mask.y = __hgt(a.y, b.y) ? one : zero;
  return mask;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief Element-wise NaN check; each element is 1.0 if NaN, otherwise 0.0
 */
__device__ __nv_bfloat162 __hisnan2(const __nv_bfloat162 a) {
  const __nv_bfloat16 one = __float2bfloat16(1.0f);
  const __nv_bfloat16 zero = __float2bfloat16(0.0f);
  __nv_bfloat162 mask;
  mask.x = __ocml_isnan_f32(__bfloat162float(a.x)) ? one : zero;
  mask.y = __ocml_isnan_f32(__bfloat162float(a.y)) ? one : zero;
  return mask;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief Element-wise a <= b; each element is 1.0 if less than or equal, otherwise 0.0
 */
__device__ __nv_bfloat162 __hle2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  const __nv_bfloat16 one = __float2bfloat16(1.0f);
  const __nv_bfloat16 zero = __float2bfloat16(0.0f);
  __nv_bfloat162 mask;
  mask.x = __hle(a.x, b.x) ? one : zero;
  mask.y = __hle(a.y, b.y) ? one : zero;
  return mask;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief Element-wise a < b; each element is 1.0 if less than, otherwise 0.0
 */
__device__ __nv_bfloat162 __hlt2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  const __nv_bfloat16 one = __float2bfloat16(1.0f);
  const __nv_bfloat16 zero = __float2bfloat16(0.0f);
  __nv_bfloat162 mask;
  mask.x = __hlt(a.x, b.x) ? one : zero;
  mask.y = __hlt(a.y, b.y) ? one : zero;
  return mask;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief Returns the element-wise maximum (__ocml_fmax_f32) of two bfloat162 values
 */
__device__ __nv_bfloat162 __hmax2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  const float low = __ocml_fmax_f32(__bfloat162float(a.x), __bfloat162float(b.x));
  const float high = __ocml_fmax_f32(__bfloat162float(a.y), __bfloat162float(b.y));
  __nv_bfloat162 result;
  result.x = __float2bfloat16(low);
  result.y = __float2bfloat16(high);
  return result;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief Returns the element-wise minimum (__ocml_fmin_f32) of two bfloat162 values
 */
__device__ __nv_bfloat162 __hmin2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  const float low = __ocml_fmin_f32(__bfloat162float(a.x), __bfloat162float(b.x));
  const float high = __ocml_fmin_f32(__bfloat162float(a.y), __bfloat162float(b.y));
  __nv_bfloat162 result;
  result.x = __float2bfloat16(low);
  result.y = __float2bfloat16(high);
  return result;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_COMP
 * \brief Element-wise a != b (__hne); each element is 1.0 if not equal, otherwise 0.0
 */
__device__ __nv_bfloat162 __hne2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
  const __nv_bfloat16 one = __float2bfloat16(1.0f);
  const __nv_bfloat16 zero = __float2bfloat16(0.0f);
  __nv_bfloat162 mask;
  mask.x = __hne(a.x, b.x) ? one : zero;
  mask.y = __hne(a.y, b.y) ? one : zero;
  return mask;
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Calculate ceil of bfloat16 (evaluated in float via __ocml_ceil_f32)
 */
__device__ __nv_bfloat16 hceil(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_ceil_f32(widened));
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Calculate cosine of bfloat16 (evaluated in float via __ocml_cos_f32)
 */
__device__ __nv_bfloat16 hcos(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_cos_f32(widened));
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Calculate exponential of bfloat16 (evaluated in float via __ocml_exp_f32)
 */
__device__ __nv_bfloat16 hexp(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_exp_f32(widened));
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Calculate base-10 exponential of bfloat16 (evaluated in float via __ocml_exp10_f32)
 */
__device__ __nv_bfloat16 hexp10(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_exp10_f32(widened));
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Calculate base-2 exponential of bfloat16 (evaluated in float via __ocml_exp2_f32)
 */
__device__ __nv_bfloat16 hexp2(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_exp2_f32(widened));
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Calculate floor of bfloat16 (evaluated in float via __ocml_floor_f32)
 */
__device__ __nv_bfloat16 hfloor(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_floor_f32(widened));
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Natural logarithm of a bfloat16 value
 *
 * The input is widened to float, passed through __ocml_log_f32, and the
 * result is converted back to bfloat16.
 */
__device__ __nv_bfloat16 hlog(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  const float logarithm = __ocml_log_f32(widened);
  return __float2bfloat16(logarithm);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Base-10 logarithm of a bfloat16 value
 *
 * The input is widened to float, passed through __ocml_log10_f32, and the
 * result is converted back to bfloat16.
 */
__device__ __nv_bfloat16 hlog10(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  const float logarithm = __ocml_log10_f32(widened);
  return __float2bfloat16(logarithm);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Base-2 logarithm of a bfloat16 value
 *
 * The input is widened to float, passed through __ocml_log2_f32, and the
 * result is converted back to bfloat16.
 */
__device__ __nv_bfloat16 hlog2(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  const float logarithm = __ocml_log2_f32(widened);
  return __float2bfloat16(logarithm);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Reciprocal of a bfloat16 value
 *
 * The input is widened to float, 1.0f is divided by it in float arithmetic,
 * and the quotient is converted back to bfloat16.
 */
__device__ __nv_bfloat16 hrcp(const __nv_bfloat16 h) {
  const float denominator = __bfloat162float(h);
  const float reciprocal = 1.0f / denominator;
  return __float2bfloat16(reciprocal);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Round a bfloat16 value to an integral value
 *
 * The input is widened to float, passed through __ocml_rint_f32, and the
 * result is converted back to bfloat16.
 */
__device__ __nv_bfloat16 hrint(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  const float rounded = __ocml_rint_f32(widened);
  return __float2bfloat16(rounded);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Reciprocal square root of a bfloat16 value
 *
 * The input is widened to float, passed through __ocml_rsqrt_f32, and the
 * result is converted back to bfloat16.
 */
__device__ __nv_bfloat16 hrsqrt(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  const float inv_root = __ocml_rsqrt_f32(widened);
  return __float2bfloat16(inv_root);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Sine of a bfloat16 value
 *
 * The input is widened to float, passed through __ocml_sin_f32, and the
 * result is converted back to bfloat16.
 */
__device__ __nv_bfloat16 hsin(const __nv_bfloat16 h) {
  const float angle = __bfloat162float(h);
  const float sine = __ocml_sin_f32(angle);
  return __float2bfloat16(sine);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Square root of a bfloat16 value
 *
 * The input is widened to float, passed through __ocml_sqrt_f32, and the
 * result is converted back to bfloat16.
 */
__device__ __nv_bfloat16 hsqrt(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  const float root = __ocml_sqrt_f32(widened);
  return __float2bfloat16(root);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT16_MATH
 * \brief Truncation of a bfloat16 value
 *
 * The input is widened to float, passed through __ocml_trunc_f32, and the
 * result is converted back to bfloat16.
 */
__device__ __nv_bfloat16 htrunc(const __nv_bfloat16 h) {
  const float widened = __bfloat162float(h);
  const float truncated = __ocml_trunc_f32(widened);
  return __float2bfloat16(truncated);
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise ceiling of a bfloat162 value
 *
 * Applies hceil() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2ceil(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hceil(h.x);
  const __nv_bfloat16 lane_y = hceil(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise cosine of a bfloat162 value
 *
 * Applies hcos() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2cos(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hcos(h.x);
  const __nv_bfloat16 lane_y = hcos(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise exponential of a bfloat162 value
 *
 * Applies hexp() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2exp(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hexp(h.x);
  const __nv_bfloat16 lane_y = hexp(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise base-10 exponential of a bfloat162 value
 *
 * Applies hexp10() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2exp10(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hexp10(h.x);
  const __nv_bfloat16 lane_y = hexp10(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise base-2 exponential of a bfloat162 value
 *
 * Applies hexp2() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2exp2(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hexp2(h.x);
  const __nv_bfloat16 lane_y = hexp2(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise floor of a bfloat162 value
 *
 * Applies hfloor() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2floor(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hfloor(h.x);
  const __nv_bfloat16 lane_y = hfloor(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise natural logarithm of a bfloat162 value
 *
 * Applies hlog() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2log(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hlog(h.x);
  const __nv_bfloat16 lane_y = hlog(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise base-10 logarithm of a bfloat162 value
 *
 * Applies hlog10() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2log10(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hlog10(h.x);
  const __nv_bfloat16 lane_y = hlog10(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise base-2 logarithm of a bfloat162 value
 *
 * Applies hlog2() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2log2(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hlog2(h.x);
  const __nv_bfloat16 lane_y = hlog2(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise reciprocal of a bfloat162 value
 *
 * Applies hrcp() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2rcp(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hrcp(h.x);
  const __nv_bfloat16 lane_y = hrcp(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise rounding of a bfloat162 value to integral values
 *
 * Applies hrint() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2rint(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hrint(h.x);
  const __nv_bfloat16 lane_y = hrint(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise reciprocal square root of a bfloat162 value
 *
 * Applies hrsqrt() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2rsqrt(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hrsqrt(h.x);
  const __nv_bfloat16 lane_y = hrsqrt(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise sine of a bfloat162 value
 *
 * Applies hsin() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2sin(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hsin(h.x);
  const __nv_bfloat16 lane_y = hsin(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise square root of a bfloat162 value
 *
 * Applies hsqrt() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2sqrt(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = hsqrt(h.x);
  const __nv_bfloat16 lane_y = hsqrt(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}

/**
 * \ingroup CUDA_INTRINSIC_BFLOAT162_MATH
 * \brief Lane-wise truncation of a bfloat162 value
 *
 * Applies htrunc() to the \c x lane and then to the \c y lane of \p h.
 */
__device__ __nv_bfloat162 h2trunc(const __nv_bfloat162 h) {
  const __nv_bfloat16 lane_x = htrunc(h.x);
  const __nv_bfloat16 lane_y = htrunc(h.y);
  return __nv_bfloat162{lane_x, lane_y};
}
}
#endif // defined(__cplusplus) && !defined(__CUDACC__)
#endif // __BF16_INTERNAL_H__


