Commit c6ec6638 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into auto_contig_fix

parents b42c7b41 a6d1540f
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include <migraphx/config.hpp>
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
#if defined(__GNUC__) and !defined(__clang__)
return MIGRAPHX_CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
return __builtin_bit_cast(To, fr);
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
......@@ -38,15 +38,12 @@ struct dynamic_loader_impl;
struct MIGRAPHX_EXPORT dynamic_loader
{
#ifndef _WIN32
template <class T>
static fs::path path(T* address)
{
return path(reinterpret_cast<void*>(address));
}
static fs::path path(void* address);
#endif
static optional<dynamic_loader> try_load(const fs::path& p);
dynamic_loader() = default;
......
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
// We are clipping/saturation in down conversion by default. Unclipped version is not tested and
// shouldn't be used without having enough tests.
// logic is based on clipping table from here : https://onnx.ai/onnx/technical/float8.html#cast
// NOLINTNEXTLINE
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#include <cmath>
#include <cstdint>
#include <climits>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <sstream>
#include <iostream>
#include <string>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/float8_impl.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace fp8 {
enum class rounding_mode
{
standard, // standard rounding is doing RNE -- round to nearest even
stochastic
};
enum class f8_type
{
bf8 = 0, // s1e5m2
fp8 = 1 // s1e4m3
};
template <typename T, bool FNUZ = true>
class numeric_limits;
template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ = true>
struct float8
{
uint8_t data = 0x00;
// default constructor
constexpr float8() = default;
// default copy constructor
constexpr float8(const float8& y) = default;
struct from_bits_t
{
};
static constexpr from_bits_t from_bits() { return from_bits_t(); }
explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {}
explicit constexpr float8(
float v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0)
{
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
else
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping}
}
}
inline constexpr operator float() const
{
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data);
} // else
return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
}
inline constexpr bool is_zero() const
{
if constexpr(FNUZ)
{
return data == 0x00;
}
else
{
return (data == 0x00) or (data == 0x80);
}
}
inline constexpr bool is_nan() const
{
if constexpr(FNUZ)
{
return data == 0x80;
}
else
{
if(T == migraphx::fp8::f8_type::bf8)
{
return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or
(data == 0xFE) or (data == 0xFF);
}
else
{
return (data == 0x7F) or (data == 0xFF);
}
}
}
inline constexpr bool is_inf() const
{
if constexpr(FNUZ)
{
return data == 0x80;
}
else
{
if(T == migraphx::fp8::f8_type::bf8)
{
return (data == 0x7C) or (data == 0xFC);
}
else
{
// no infinities in e4m3fn, represent them as NaNs
return (data == 0x7F) or (data == 0xFF);
}
}
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
constexpr float8& operator unary_op(const float8& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
return *this; \
} \
constexpr float8& operator unary_op(const float& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
return *this; \
}
MIGRAPHX_FP8_UNARY_OP(*=, *)
MIGRAPHX_FP8_UNARY_OP(-=, -)
MIGRAPHX_FP8_UNARY_OP(+=, +)
MIGRAPHX_FP8_UNARY_OP(/=, /)
inline constexpr float8& operator=(const float8& rhs) = default;
inline constexpr float8& operator=(float8&& rhs) noexcept = default;
inline constexpr float8& operator=(float rhs)
{
*this = static_cast<float8>(rhs);
return *this;
}
inline constexpr bool operator==(const float8& rhs) const
{
if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return false;
else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data))
return true;
return false;
}
inline constexpr bool operator<(const float8& rhs) const
{
const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs);
return we < them;
}
inline constexpr bool operator>(const float8& rhs) const
{
const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs);
return we > them;
}
};
// https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
/*
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
inline constexpr U operator binary_op(const T& lhs, const T& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats for binary ops
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP_GEN_FOR(T) \
MIGRAPHX_FP8_BINARY_OP(*, T, T) \
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fn)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2fnuz)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fnuz)
*/
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fnuz& rhs)
{
return os << static_cast<float>(rhs);
}
inline fp8e4m3fnuz fabs(fp8e4m3fnuz v)
{
v.data = v.data & 0x7F; // NOLINT
return v;
}
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fn& rhs)
{
return os << static_cast<float>(rhs);
}
inline fp8e4m3fn fabs(fp8e4m3fn v)
{
v.data = v.data & 0x7F; // NOLINT
return v;
}
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const fp8e5m2fnuz& rhs)
{
return os << static_cast<float>(rhs);
}
inline fp8e5m2fnuz fabs(fp8e5m2fnuz v)
{
v.data = v.data & 0x7F; // NOLINT
return v;
}
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const fp8e5m2& rhs)
{
return os << static_cast<float>(rhs);
}
inline fp8e5m2 fabs(fp8e5m2 v)
{
v.data = v.data & 0x7F; // NOLINT
return v;
}
template <>
class numeric_limits<fp8e4m3fnuz>
{
public:
static constexpr bool has_infinity = false;
static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); }
// NOLINTNEXTLINE
static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); }
static constexpr fp8e4m3fnuz max() { return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr fp8e4m3fnuz min() { return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); }
static constexpr fp8e4m3fnuz lowest() { return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); }
};
template <>
class numeric_limits<fp8e4m3fn>
{
public:
static constexpr bool has_infinity = false;
static constexpr fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); }
// NOLINTNEXTLINE
static constexpr fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); }
static constexpr fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); }
static constexpr fp8e4m3fn lowest() { return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); }
};
template <>
class numeric_limits<fp8e5m2fnuz>
{
public:
static constexpr bool has_infinity = false;
static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); }
static constexpr fp8e5m2fnuz quiet_NaN() // NOLINT
{
return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits());
}
static constexpr fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); }
static constexpr fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); }
};
template <>
class numeric_limits<fp8e5m2>
{
public:
static constexpr bool has_infinity = true;
static constexpr fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); }
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } // NOLINT
static constexpr fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
static constexpr fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
// 7C and FC both are infinity
static constexpr fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); }
};
} // namespace fp8
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
// =================================================================================================
// define numeric limits for the new data type
// NOLINTBEGIN
namespace std {
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \
inline bool isfinite(T x) { return not x.is_inf() and not x.is_nan(); } \
inline bool isnan(T x) { return x.is_nan(); } \
template <> \
class numeric_limits<T> : public migraphx::fp8::numeric_limits<T> \
{ \
}; \
template <class U> \
struct common_type<T, U> : std::common_type<float, U> \
{ \
}; \
template <class U> \
struct common_type<U, T> : std::common_type<float, U> \
{ \
}; \
template <> \
struct common_type<T, T> \
{ \
using type = T; \
};
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz)
} // namespace std
// NOLINTEND
// =================================================================================================
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#include <algorithm>
#include <cstdint>
#include <type_traits>
#include <migraphx/config.hpp>
#include <migraphx/bit_cast.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace fp8 {
namespace impl {
// NOLINTBEGIN
template <uint32_t Wm, uint32_t We, typename T, bool NegativeZeroNan, bool Clip>
constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0)
{
constexpr bool is_float = std::is_same<T, float>::value;
// half is not supported for now
constexpr bool is_half = false;
static_assert(Wm + We == 7, "Wm+We==7");
static_assert(is_float or is_half, "Only float can be cast to f8");
const uint32_t mfmt = (sizeof(T) == 4) ? 23 : 10;
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
if constexpr(sizeof(T) == 4)
x = migraphx::bit_cast<uint32_t>(f_x);
else
x = migraphx::bit_cast<uint16_t>(f_x);
uint32_t head = 0;
uint32_t mantissa = 0;
int exponent = 0;
uint32_t bias = 0;
uint32_t sign = 0;
if constexpr(sizeof(T) == 4)
{
head = x & 0xFF800000;
mantissa = x & 0x7FFFFF;
exponent = (head >> 23) & 0xFF;
sign = head >> 31;
bias = 127;
}
else
{
head = x & 0xFC00;
mantissa = x & 0x3FF;
exponent = (head >> 10) & 0x1F;
sign = head >> 15;
bias = 15;
}
uint32_t signed_inf = (sign << 7) + (((1 << We) - 1) << Wm);
uint32_t signed_all_ones = (sign << 7) + ((((1 << We) - 1) << Wm) + ((1 << Wm) - 1));
// Calcualte maximum singed value FLT_MAX, FLT_MIN
uint32_t signed_max = signed_all_ones;
if(not NegativeZeroNan)
signed_max = (Wm == 2) ? (signed_max - 4) : (signed_max - 1);
// Deal with inf and NaNs
if(NegativeZeroNan) // For the FNUZ cases, it is simple just return NaNs
{
if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or
(sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00)))
return 0x80;
}
else
{
// calculate most common NaN mantissa for FP8, which is all Ones in binary
uint32_t nan_mantissa = 1;
for(auto i = 1; i < Wm; ++i)
{
nan_mantissa |= (nan_mantissa << 1);
}
if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or
(sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00)))
{
// infinity
if(mantissa == 0)
{
if(sign == 0)
return (Wm == 2) ? 0x7B : 0x7E;
else
return (Wm == 2) ? 0xFB : 0xFE;
}
else // NaNs
return signed_inf + nan_mantissa;
}
}
// handle positive zero
if(x == 0)
return 0;
// handle negative zero
else if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000))
{
return NegativeZeroNan ? 0 : 0x80; // For FNUZ types neg zero is just positive zero
}
/* First need to check if it is normal or denorm as there is a difference of implict 1
Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
RNE, no need to add rng. Then probably need to check whether there is carry and adjust
exponent and mantissa again*/
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
const int f8_bias = (1 << (We - 1u)) - 1 + (NegativeZeroNan ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
/* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
f8_exponent is the converted f8 exponent with bias encoding
exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
the difference needs to be adjusted and mantissa shifted*/
int act_exponent = 0;
int f8_exponent = 0;
int exponent_diff = 0;
if(exponent == 0 and mantissa != 0)
{ // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal
has exponent bias 15 while bf8 with FNUZ has exponent bias 16. It means that there are some
numbers in fp16 denormal but they are bf8 (FNUZ) normals - smallest bf8 (FNUZ) normal is
2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1
are bf8 (FNUZ) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = 1 - bias;
exponent_diff = f8_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
}
else
{ // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if(act_exponent <= f8_denormal_act_exponent)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 FNUZ mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 FNUZ */
exponent_diff = f8_denormal_act_exponent - act_exponent;
}
else
{ // both fp32/fp16 and f8 are in normal range
exponent_diff =
0; // exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa += (1u << mfmt); // Add the implicit 1 into mantissa
}
// need to know whether the number is right in the middle of two adjacent fp8 numbers. use max
// value of 31 to avoid undefined behaviour
bool midpoint = (mantissa & ((1u << std::min(31u, mfmt - Wm + exponent_diff)) - 1)) ==
(1u << std::min(31u, mfmt - Wm + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint.
*/
if(exponent_diff > 0)
mantissa >>= std::min(31u, uint32_t(exponent_diff));
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
uint32_t drop_mask = (1u << (mfmt - Wm)) - 1;
bool odd =
mantissa & (1u << (mfmt - Wm)); // if the least significant bit that is not truncated is 1
/*
This part is doing rounding by adding mantissa part that is going to get dropped.
e.g. if the dropped part for less than 0.5 than it would round down.
if the dropped part is more than 0.5 then it would round up by rolling carry to LSB of retained
mantissa.
For the mid point when bit pattern is like this for Odd: `xy1:10000000` for Odd and
`xy0:10000000` for the Even. where `:` is delimiter for dropped v/s retained part.
For the odd case :
this will add xy1:10000000 + 000:10000000 which would roll over carry to LSB of retained
part making it RNE.
For the even case : this will add xy0:10000000 + 000:01111111 which would
round down and keep number Even
*/
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
// Now we deal with overflow
if(f8_exponent == 0 and ((1 << mfmt) & mantissa))
{
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
else if((1 << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
f8_exponent++;
}
mantissa >>= (mfmt - Wm);
// above range: quantize to maximum possible float of the same sign
// for e5m2 case, max_exp is 14, since exp = 15 is reserved for Infs and Nans
const int max_exp = (1 << We) - ((NegativeZeroNan or Wm == 3) ? 1 : 2);
if(f8_exponent > max_exp)
{
if(Clip)
return signed_max;
else
{
// https://onnx.ai/onnx/technical/float8.html#cast
if(NegativeZeroNan)
return 0x80;
else
return (Wm == 2) ? signed_inf : signed_all_ones;
}
}
if(f8_exponent == 0 and mantissa == 0)
return NegativeZeroNan ? 0 : (sign << 7);
mantissa &= (1 << Wm) - 1;
return (sign << 7) | (f8_exponent << Wm) | mantissa;
}
// NOLINTEND
template <uint32_t Wm, uint32_t We, typename T, bool NegativeZeroNan>
constexpr T cast_from_f8(uint8_t x)
{
// half is not supported for now
constexpr bool is_half = false;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_float or is_half, "Only float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
// NOLINTNEXTLINE
T f_inf, f_neg_inf, f_nan, f_neg0;
if constexpr(is_float)
{
const uint32_t if_inf = 0x7F800000;
const uint32_t if_neg_inf = 0xFF800000;
const uint32_t if_nan = 0x7F800001;
const uint32_t if_neg0 = 0x80000000;
f_inf = migraphx::bit_cast<float>(if_inf);
f_neg_inf = migraphx::bit_cast<float>(if_neg_inf);
f_nan = migraphx::bit_cast<float>(if_nan);
f_neg0 = migraphx::bit_cast<float>(if_neg0);
}
if(x == 0)
return 0;
uint32_t sign = x >> 7; // NOLINT
uint32_t mantissa = x & ((1 << Wm) - 1); // NOLINT
int exponent = (x & 0x7F) >> Wm; // NOLINT
if(NegativeZeroNan)
{
if(x == 0x80)
return f_nan;
}
else
{
if(x == 0x80)
return f_neg0;
if(exponent == ((1 << We) - 1) and Wm == 2) // NOLINT
return (mantissa == 0) ? (sign ? f_neg_inf : f_inf) : f_nan;
else if(Wm == 3 and (x == 0x7F or x == 0xFF))
return f_nan;
}
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
const int exp_low_cutoff =
(1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - (32 - Wm);
mantissa <<= sh; // NOLINT
exponent += 1 - sh;
mantissa &= ((1 << Wm) - 1); // NOLINT
}
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - Wm; // NOLINT
// subnormal output (occurs when T=half, We=5, negative_zero_nan=true)
if(exponent <= 0)
{
mantissa |= 1 << wmo; // NOLINT
mantissa >>= 1 - exponent; // NOLINT
exponent = 0;
}
if(sizeof(T) == 2)
retval = (sign << 15) | (exponent << 10) | mantissa; // NOLINT
else
retval = (sign << 31) | (exponent << 23) | mantissa; // NOLINT
return migraphx::bit_cast<T>(retval);
}
} // namespace impl
} // namespace fp8
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL
......@@ -27,6 +27,7 @@
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -67,6 +68,18 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
{
};
template <>
struct common_type<migraphx::fp8::fp8e4m3fnuz, migraphx::half>
{
using type = float;
};
template <>
struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fnuz>
{
using type = float;
};
template <>
struct common_type<migraphx::half, migraphx::half>
{
......
......@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
});
});
return all_same;
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{
if(ins->outputs().size() == 1)
......@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
template <class... Ms>
auto skip_broadcasts_transposes_contiguous(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "transpose"))(ms...);
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
......
......@@ -29,6 +29,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -95,11 +96,11 @@ struct binary : op_name<Derived>
{
argument result{dyn_out.computed_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
par_transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
return result;
}
......
......@@ -70,7 +70,8 @@ struct pooling
// 2 smaller than the input tensor rank (NCHW layout)
std::vector<std::size_t> lengths = {1, 1};
// Dilations are not supported at this time.
// Spacing between the elements of the pooling kernel. Must be the same ndim as lengths.
std::vector<std::size_t> dilations = {1, 1};
// ceiling mode is a flag affecting output size
// or equivalently, placements of the pooling kernel.
......@@ -99,6 +100,7 @@ struct pooling
f(self.padding_mode, "padding_mode"),
f(self.stride, "stride"),
f(self.lengths, "lengths"),
f(self.dilations, "dilations"),
f(self.ceil_mode, "ceil_mode"),
f(self.lp_order, "lp_order"),
f(self.dyn_global, "dyn_global"));
......@@ -112,14 +114,17 @@ struct pooling
return;
if((padding_mode != default_ and padding.size() != stride.size() and
(padding.size()) != stride.size() * 2) or
stride.size() != lengths.size())
stride.size() != lengths.size() or dilations.size() != lengths.size())
{
MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
}
if(std::any_of(lengths.begin(), lengths.end(), [&](auto i) { return (i == 0); }) or
std::any_of(stride.begin(), stride.end(), [&](auto i) { return (i == 0); }))
const auto is_zero = [](auto el) { return el == 0; };
if(std::any_of(lengths.begin(), lengths.end(), is_zero) or
std::any_of(stride.begin(), stride.end(), is_zero) or
std::any_of(dilations.begin(), dilations.end(), is_zero))
{
MIGRAPHX_THROW("POOLING: size 0 pooling kernel or stride");
MIGRAPHX_THROW("POOLING: size 0 pooling kernel or stride or dilations");
}
// TODO: update lowering to run the reference
......@@ -142,6 +147,11 @@ struct pooling
value attributes() const { return {{"normalize_padding", "padding"}}; }
inline std::size_t dilate_dim(std::size_t dim, std::size_t dilation) const
{
return 1 + dilation * (dim - 1);
}
std::vector<std::size_t> calc_spatial_dim_out(const std::vector<std::size_t>& input_lens,
std::size_t kdims) const
{
......@@ -151,8 +161,9 @@ struct pooling
std::size_t padding_factor = 2 * padding[i];
if(padding.size() == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
std::size_t dilated_length = dilate_dim(lengths[i], dilations[i]);
std::size_t dim_size;
if(input_lens[i + 2] + padding_factor < lengths[i])
if(input_lens[i + 2] + padding_factor < dilated_length)
{
if(padding_mode == default_)
MIGRAPHX_THROW("POOLING: not enough padding for the given kernel size");
......@@ -162,7 +173,7 @@ struct pooling
}
else
{
dim_size = input_lens[i + 2] + padding_factor - lengths[i];
dim_size = input_lens[i + 2] + padding_factor - dilated_length;
}
std::size_t len =
(ceil_mode)
......@@ -331,6 +342,7 @@ struct pooling
int start = static_cast<int>(idx_o[dim] * stride[d_2]) -
static_cast<int>(padding_vals[d_2]);
int end;
std::size_t dilated_kernel_dim = dilate_dim(kernel_dims[d_2], dilations[d_2]);
// NOLINT
if(count_include_pad and ceil_mode and (mode != pooling_mode::max))
{
......@@ -340,15 +352,14 @@ struct pooling
// padding. Clip out-of-bounds indexes but not padding.
// Check if this kernel extends beyond the padding at end of dimension
end = std::min(start + kernel_dims[d_2],
end = std::min(start + dilated_kernel_dim,
in_lens[dim] + static_cast<int>(padding_vals[d_2]));
}
else
{
// In non-ceiling mode, when
// count_include_pad is false, or for max pooling, clip off padding.
end = std::min(start + kernel_dims[d_2], in_lens[dim]);
start = std::max(start, 0);
end = std::min(start + dilated_kernel_dim, in_lens[dim]);
}
win_start.push_back(start);
if(end < start)
......@@ -366,6 +377,16 @@ struct pooling
// for each element in the window...
shape_for_each(win_shape, [&](const auto& idx_w) {
// Skip elements that belong to the dilated area
for(size_t axis = 0; axis < idx_w.size(); ++axis)
{
if(idx_w[axis] % dilations[axis])
{
pool_size -= 1;
return;
}
}
// the coordinates of this element
auto idx = idx_o;
......@@ -390,7 +411,15 @@ struct pooling
// this is a padding element. Padding locations
// don't contribute to average or max pooling total but can play in
// lpnorm pooling.
output_val = op(output_val, 0);
if(mode == pooling_mode::lpnorm)
{
output_val = op(output_val, op.template init<Type>());
}
if(mode == pooling_mode::average)
{
// Ignore padding
pool_size -= 1;
}
}
});
output[i] = Type(op.final(output_val, pool_size));
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_max : scatternd_op<scatternd_max>
{
scatternd_max() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x = std::max(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_min : scatternd_op<scatternd_min>
{
scatternd_min() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x = std::min(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived>
auto k = indices_shape.lens().back();
auto q = indices_shape.ndim();
auto r = dyn_out.computed_shape.ndim();
par_for(updates_shape.elements(), [&](const auto i) {
for(auto i = 0u; i < updates_shape.elements(); ++i)
{
auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0);
std::copy(
......@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived>
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
});
}
});
});
......
......@@ -31,6 +31,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -84,10 +85,10 @@ struct unary : op_name<Derived>
argument result{dyn_out.computed_shape};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
par_transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
});
return result;
......
......@@ -119,6 +119,8 @@
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/scatternd_max.hpp>
#include <migraphx/op/scatternd_min.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#include <migraphx/config.hpp>
#if MIGRAPHX_HAS_EXECUTORS
#include <execution>
#else
#include <migraphx/simple_par_for.hpp>
#endif
#include <algorithm>
#include <mutex>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
struct exception_list
{
std::vector<std::exception_ptr> exceptions;
std::mutex m;
void add_exception()
{
std::lock_guard<std::mutex> guard(m);
exceptions.push_back(std::current_exception());
}
template <class F>
auto collect(F f)
{
return [f, this](auto&&... xs) {
try
{
f(std::forward<decltype(xs)>(xs)...);
}
catch(...)
{
this->add_exception();
}
};
}
void throw_if_exception() const
{
if(not exceptions.empty())
std::rethrow_exception(exceptions.front());
}
};
} // namespace detail
template <class InputIt, class OutputIt, class UnaryOperation>
OutputIt par_transform(InputIt first1, InputIt last1, OutputIt d_first, UnaryOperation unary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(std::execution::par, first1, last1, d_first, std::move(unary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = unary_op(first1[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt1, class InputIt2, class OutputIt, class BinaryOperation>
OutputIt par_transform(
InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt d_first, BinaryOperation binary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(
std::execution::par, first1, last1, first2, d_first, std::move(binary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = binary_op(first1[i], first2[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt, class UnaryFunction>
void par_for_each(InputIt first, InputIt last, UnaryFunction f)
{
#if MIGRAPHX_HAS_EXECUTORS
// Propagate the exception
detail::exception_list ex;
std::for_each(std::execution::par, first, last, ex.collect(std::move(f)));
ex.throw_if_exception();
#else
simple_par_for(last - first, [&](auto i) { f(first[i]); });
#endif
}
template <class... Ts>
auto par_copy_if(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::copy_if(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::copy_if(std::forward<Ts>(xs)...);
#endif
}
template <class... Ts>
auto par_sort(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::sort(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::sort(std::forward<Ts>(xs)...);
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
......@@ -24,93 +24,23 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
#include <migraphx/par.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
}
}
template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
void par_for(std::size_t n, F f)
{
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f);
using iterator = basic_iota_iterator<id, std::size_t>;
par_for_each(iterator{0, {}}, iterator{n, {}}, f);
}
template <class F>
void par_for(std::size_t n, F f)
void par_for(std::size_t n, std::size_t, F f)
{
const int min_grain = 8;
par_for(n, min_grain, f);
par_for(n, f);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -26,6 +26,7 @@
#include <string>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -34,6 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
......@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void simple_par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
}
}
template <class F>
void simple_par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
simple_par_for_impl(n, threadsize, f);
}
template <class F>
void simple_par_for(std::size_t n, F f)
{
const int min_grain = 8;
simple_par_for(n, min_grain, f);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -28,25 +28,35 @@
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_DETAIL_DEFINE_TRAIT(trait) \
template <class X> \
struct trait : std::trait<X> \
{ \
};
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
};
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed);
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
template <class T>
using accumulator_type =
std::conditional_t<is_floating_point<T>{},
......
......@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED)
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS onnx.proto)
add_library(onnx-proto STATIC ${PROTO_SRCS})
target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(onnx-proto PRIVATE -w)
if(MSVC)
target_compile_options(onnx-proto PRIVATE /w)
else()
target_compile_options(onnx-proto PRIVATE -w)
endif()
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
......@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
migraphx_generate_export_header(migraphx_onnx)
rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL")
target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
if(NOT WIN32)
target_link_libraries(migraphx_onnx PRIVATE "-Wl,--exclude-libs,ALL")
endif()
target_link_libraries(migraphx_onnx PUBLIC migraphx)
rocm_install_targets(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment