"model/vscode:/vscode.git/clone" did not exist on "543240fb5f0eb1a5443fd2b45f857e7dd4dcbfed"
Commit fed23ec7 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents 9d933920 225873aa
......@@ -89,7 +89,7 @@ requests==2.28.2
# via
# pygithub
# sphinx
rocm-docs-core==0.27.0
rocm-docs-core==0.28.0
# via -r requirements.in
smmap==5.0.0
# via gitdb
......
......@@ -44,7 +44,8 @@
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#ifdef __cplusplus
......
......@@ -134,8 +134,9 @@ fs::path dynamic_loader::path(void* address)
{
HMODULE module = nullptr;
if(GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
static_cast<LPCSTR>(address), &module) == 0)
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
static_cast<LPCSTR>(address),
&module) == 0)
{
auto err = GetLastError();
MIGRAPHX_THROW("Unable to obtain module handle, error = " + std::to_string(err));
......
......@@ -219,9 +219,8 @@ struct find_pointwise_reshape_pointwise
auto reshape_input = [&](const auto& ins_to_insert) {
return [&](auto input) {
auto c = m.insert_instruction(ins_to_insert, make_op("contiguous"), input);
return m.insert_instruction(
ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), c);
ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), input);
};
};
auto x_inputs = x_ins->inputs();
......
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include <migraphx/config.hpp>
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
#if defined(__GNUC__) and !defined(__clang__)
return MIGRAPHX_CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
return __builtin_bit_cast(To, fr);
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
// We are clipping/saturation in down conversion by default. Unclipped version is not tested and
// shouldn't be used without having enough tests.
// logic is based on clipping table from here : https://onnx.ai/onnx/technical/float8.html#cast
// NOLINTNEXTLINE
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#include <cmath>
#include <cstdint>
#include <climits>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <sstream>
#include <iostream>
#include <string>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/float8_impl.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace fp8 {
enum class rounding_mode
{
standard, // standard rounding is doing RNE -- round to nearest even
stochastic
};
enum class f8_type
{
bf8 = 0, // s1e5m2
fp8 = 1 // s1e4m3
};
template <typename T, bool FNUZ = true>
class numeric_limits;
template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ = true>
struct float8
{
uint8_t data = 0x00;
// default constructor
constexpr float8() = default;
// default copy constructor
constexpr float8(const float8& y) = default;
struct from_bits_t
{
};
static constexpr from_bits_t from_bits() { return from_bits_t(); }
explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {}
explicit constexpr float8(
float v,
migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard,
uint32_t rng = 0)
{
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
else
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data = migraphx::fp8::impl::
cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>(
v, (rm == migraphx::fp8::rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping}
}
}
inline constexpr operator float() const
{
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data);
} // else
return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data);
}
inline constexpr bool is_zero() const
{
if constexpr(FNUZ)
{
return data == 0x00;
}
else
{
return (data == 0x00) or (data == 0x80);
}
}
inline constexpr bool is_nan() const
{
if constexpr(FNUZ)
{
return data == 0x80;
}
else
{
if(T == migraphx::fp8::f8_type::bf8)
{
return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or
(data == 0xFE) or (data == 0xFF);
}
else
{
return (data == 0x7F) or (data == 0xFF);
}
}
}
inline constexpr bool is_inf() const
{
if constexpr(FNUZ)
{
return data == 0x80;
}
else
{
if(T == migraphx::fp8::f8_type::bf8)
{
return (data == 0x7C) or (data == 0xFC);
}
else
{
// no infinities in e4m3fn, represent them as NaNs
return (data == 0x7F) or (data == 0xFF);
}
}
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
constexpr float8& operator unary_op(const float8& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
return *this; \
} \
constexpr float8& operator unary_op(const float& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
return *this; \
}
MIGRAPHX_FP8_UNARY_OP(*=, *)
MIGRAPHX_FP8_UNARY_OP(-=, -)
MIGRAPHX_FP8_UNARY_OP(+=, +)
MIGRAPHX_FP8_UNARY_OP(/=, /)
inline constexpr float8& operator=(const float8& rhs) = default;
inline constexpr float8& operator=(float8&& rhs) noexcept = default;
inline constexpr float8& operator=(float rhs)
{
*this = static_cast<float8>(rhs);
return *this;
}
inline constexpr bool operator==(const float8& rhs) const
{
if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return false;
else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data))
return true;
return false;
}
inline constexpr bool operator<(const float8& rhs) const
{
const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs);
return we < them;
}
inline constexpr bool operator>(const float8& rhs) const
{
const auto we = static_cast<float>(*this);
const auto them = static_cast<float>(rhs);
return we > them;
}
};
// https://onnx.ai/onnx/technical/float8.html
using fp8e4m3fn = float8<migraphx::fp8::f8_type::fp8, false>;
using fp8e5m2 = float8<migraphx::fp8::f8_type::bf8, false>;
using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
/*
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
inline constexpr U operator binary_op(const T& lhs, const T& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// TODO: these should return floats for binary ops
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP_GEN_FOR(T) \
MIGRAPHX_FP8_BINARY_OP(*, T, T) \
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fn)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2fnuz)
MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fnuz)
*/
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fnuz& rhs)
{
return os << static_cast<float>(rhs);
}
inline fp8e4m3fnuz fabs(fp8e4m3fnuz v)
{
v.data = v.data & 0x7F; // NOLINT
return v;
}
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fn& rhs)
{
return os << static_cast<float>(rhs);
}
inline fp8e4m3fn fabs(fp8e4m3fn v)
{
v.data = v.data & 0x7F; // NOLINT
return v;
}
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const fp8e5m2fnuz& rhs)
{
return os << static_cast<float>(rhs);
}
inline fp8e5m2fnuz fabs(fp8e5m2fnuz v)
{
v.data = v.data & 0x7F; // NOLINT
return v;
}
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const fp8e5m2& rhs)
{
return os << static_cast<float>(rhs);
}
inline fp8e5m2 fabs(fp8e5m2 v)
{
v.data = v.data & 0x7F; // NOLINT
return v;
}
template <>
class numeric_limits<fp8e4m3fnuz>
{
public:
static constexpr bool has_infinity = false;
static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); }
// NOLINTNEXTLINE
static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); }
static constexpr fp8e4m3fnuz max() { return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr fp8e4m3fnuz min() { return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); }
static constexpr fp8e4m3fnuz lowest() { return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); }
};
template <>
class numeric_limits<fp8e4m3fn>
{
public:
static constexpr bool has_infinity = false;
static constexpr fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); }
// NOLINTNEXTLINE
static constexpr fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); }
static constexpr fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01
static constexpr fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); }
static constexpr fp8e4m3fn lowest() { return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); }
};
template <>
class numeric_limits<fp8e5m2fnuz>
{
public:
static constexpr bool has_infinity = false;
static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); }
static constexpr fp8e5m2fnuz quiet_NaN() // NOLINT
{
return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits());
}
static constexpr fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); }
static constexpr fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); }
};
template <>
class numeric_limits<fp8e5m2>
{
public:
static constexpr bool has_infinity = true;
static constexpr fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); }
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } // NOLINT
static constexpr fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
static constexpr fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
// 7C and FC both are infinity
static constexpr fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); }
};
} // namespace fp8
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
// =================================================================================================
// define numeric limits for the new data type
// NOLINTBEGIN
namespace std {
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \
inline bool isfinite(T x) { return not x.is_inf() and not x.is_nan(); } \
inline bool isnan(T x) { return x.is_nan(); } \
template <> \
class numeric_limits<T> : public migraphx::fp8::numeric_limits<T> \
{ \
}; \
template <class U> \
struct common_type<T, U> : std::common_type<float, U> \
{ \
}; \
template <class U> \
struct common_type<U, T> : std::common_type<float, U> \
{ \
}; \
template <> \
struct common_type<T, T> \
{ \
using type = T; \
};
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz)
} // namespace std
// NOLINTEND
// =================================================================================================
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#include <algorithm>
#include <cstdint>
#include <type_traits>
#include <migraphx/config.hpp>
#include <migraphx/bit_cast.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace fp8 {
namespace impl {
// NOLINTBEGIN
template <uint32_t Wm, uint32_t We, typename T, bool NegativeZeroNan, bool Clip>
constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0)
{
constexpr bool is_float = std::is_same<T, float>::value;
// half is not supported for now
constexpr bool is_half = false;
static_assert(Wm + We == 7, "Wm+We==7");
static_assert(is_float or is_half, "Only float can be cast to f8");
const uint32_t mfmt = (sizeof(T) == 4) ? 23 : 10;
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type x;
if constexpr(sizeof(T) == 4)
x = migraphx::bit_cast<uint32_t>(f_x);
else
x = migraphx::bit_cast<uint16_t>(f_x);
uint32_t head = 0;
uint32_t mantissa = 0;
int exponent = 0;
uint32_t bias = 0;
uint32_t sign = 0;
if constexpr(sizeof(T) == 4)
{
head = x & 0xFF800000;
mantissa = x & 0x7FFFFF;
exponent = (head >> 23) & 0xFF;
sign = head >> 31;
bias = 127;
}
else
{
head = x & 0xFC00;
mantissa = x & 0x3FF;
exponent = (head >> 10) & 0x1F;
sign = head >> 15;
bias = 15;
}
uint32_t signed_inf = (sign << 7) + (((1 << We) - 1) << Wm);
uint32_t signed_all_ones = (sign << 7) + ((((1 << We) - 1) << Wm) + ((1 << Wm) - 1));
// Calcualte maximum singed value FLT_MAX, FLT_MIN
uint32_t signed_max = signed_all_ones;
if(not NegativeZeroNan)
signed_max = (Wm == 2) ? (signed_max - 4) : (signed_max - 1);
// Deal with inf and NaNs
if(NegativeZeroNan) // For the FNUZ cases, it is simple just return NaNs
{
if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or
(sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00)))
return 0x80;
}
else
{
// calculate most common NaN mantissa for FP8, which is all Ones in binary
uint32_t nan_mantissa = 1;
for(auto i = 1; i < Wm; ++i)
{
nan_mantissa |= (nan_mantissa << 1);
}
if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or
(sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00)))
{
// infinity
if(mantissa == 0)
{
if(sign == 0)
return (Wm == 2) ? 0x7B : 0x7E;
else
return (Wm == 2) ? 0xFB : 0xFE;
}
else // NaNs
return signed_inf + nan_mantissa;
}
}
// handle positive zero
if(x == 0)
return 0;
// handle negative zero
else if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000))
{
return NegativeZeroNan ? 0 : 0x80; // For FNUZ types neg zero is just positive zero
}
/* First need to check if it is normal or denorm as there is a difference of implict 1
Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
RNE, no need to add rng. Then probably need to check whether there is carry and adjust
exponent and mantissa again*/
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
const int f8_bias = (1 << (We - 1u)) - 1 + (NegativeZeroNan ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
/* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
f8_exponent is the converted f8 exponent with bias encoding
exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
the difference needs to be adjusted and mantissa shifted*/
int act_exponent = 0;
int f8_exponent = 0;
int exponent_diff = 0;
if(exponent == 0 and mantissa != 0)
{ // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal
has exponent bias 15 while bf8 with FNUZ has exponent bias 16. It means that there are some
numbers in fp16 denormal but they are bf8 (FNUZ) normals - smallest bf8 (FNUZ) normal is
2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1
are bf8 (FNUZ) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = 1 - bias;
exponent_diff = f8_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
}
else
{ // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if(act_exponent <= f8_denormal_act_exponent)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 FNUZ mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 FNUZ */
exponent_diff = f8_denormal_act_exponent - act_exponent;
}
else
{ // both fp32/fp16 and f8 are in normal range
exponent_diff =
0; // exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa += (1u << mfmt); // Add the implicit 1 into mantissa
}
// need to know whether the number is right in the middle of two adjacent fp8 numbers. use max
// value of 31 to avoid undefined behaviour
bool midpoint = (mantissa & ((1u << std::min(31u, mfmt - Wm + exponent_diff)) - 1)) ==
(1u << std::min(31u, mfmt - Wm + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint.
*/
if(exponent_diff > 0)
mantissa >>= std::min(31u, uint32_t(exponent_diff));
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
uint32_t drop_mask = (1u << (mfmt - Wm)) - 1;
bool odd =
mantissa & (1u << (mfmt - Wm)); // if the least significant bit that is not truncated is 1
/*
This part is doing rounding by adding mantissa part that is going to get dropped.
e.g. if the dropped part for less than 0.5 than it would round down.
if the dropped part is more than 0.5 then it would round up by rolling carry to LSB of retained
mantissa.
For the mid point when bit pattern is like this for Odd: `xy1:10000000` for Odd and
`xy0:10000000` for the Even. where `:` is delimiter for dropped v/s retained part.
For the odd case :
this will add xy1:10000000 + 000:10000000 which would roll over carry to LSB of retained
part making it RNE.
For the even case : this will add xy0:10000000 + 000:01111111 which would
round down and keep number Even
*/
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
// Now we deal with overflow
if(f8_exponent == 0 and ((1 << mfmt) & mantissa))
{
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
else if((1 << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
f8_exponent++;
}
mantissa >>= (mfmt - Wm);
// above range: quantize to maximum possible float of the same sign
// for e5m2 case, max_exp is 14, since exp = 15 is reserved for Infs and Nans
const int max_exp = (1 << We) - ((NegativeZeroNan or Wm == 3) ? 1 : 2);
if(f8_exponent > max_exp)
{
if(Clip)
return signed_max;
else
{
// https://onnx.ai/onnx/technical/float8.html#cast
if(NegativeZeroNan)
return 0x80;
else
return (Wm == 2) ? signed_inf : signed_all_ones;
}
}
if(f8_exponent == 0 and mantissa == 0)
return NegativeZeroNan ? 0 : (sign << 7);
mantissa &= (1 << Wm) - 1;
return (sign << 7) | (f8_exponent << Wm) | mantissa;
}
// NOLINTEND
template <uint32_t Wm, uint32_t We, typename T, bool NegativeZeroNan>
constexpr T cast_from_f8(uint8_t x)
{
// half is not supported for now
constexpr bool is_half = false;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_float or is_half, "Only float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
// NOLINTNEXTLINE
T f_inf, f_neg_inf, f_nan, f_neg0;
if constexpr(is_float)
{
const uint32_t if_inf = 0x7F800000;
const uint32_t if_neg_inf = 0xFF800000;
const uint32_t if_nan = 0x7F800001;
const uint32_t if_neg0 = 0x80000000;
f_inf = migraphx::bit_cast<float>(if_inf);
f_neg_inf = migraphx::bit_cast<float>(if_neg_inf);
f_nan = migraphx::bit_cast<float>(if_nan);
f_neg0 = migraphx::bit_cast<float>(if_neg0);
}
if(x == 0)
return 0;
uint32_t sign = x >> 7; // NOLINT
uint32_t mantissa = x & ((1 << Wm) - 1); // NOLINT
int exponent = (x & 0x7F) >> Wm; // NOLINT
if(NegativeZeroNan)
{
if(x == 0x80)
return f_nan;
}
else
{
if(x == 0x80)
return f_neg0;
if(exponent == ((1 << We) - 1) and Wm == 2) // NOLINT
return (mantissa == 0) ? (sign ? f_neg_inf : f_inf) : f_nan;
else if(Wm == 3 and (x == 0x7F or x == 0xFF))
return f_nan;
}
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
const int exp_low_cutoff =
(1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - (32 - Wm);
mantissa <<= sh; // NOLINT
exponent += 1 - sh;
mantissa &= ((1 << Wm) - 1); // NOLINT
}
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - Wm; // NOLINT
// subnormal output (occurs when T=half, We=5, negative_zero_nan=true)
if(exponent <= 0)
{
mantissa |= 1 << wmo; // NOLINT
mantissa >>= 1 - exponent; // NOLINT
exponent = 0;
}
if(sizeof(T) == 2)
retval = (sign << 15) | (exponent << 10) | mantissa; // NOLINT
else
retval = (sign << 31) | (exponent << 23) | mantissa; // NOLINT
return migraphx::bit_cast<T>(retval);
}
} // namespace impl
} // namespace fp8
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL
......@@ -31,6 +31,7 @@
#include <half.hpp>
#endif
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -71,6 +72,18 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
{
};
template <>
struct common_type<migraphx::fp8::fp8e4m3fnuz, migraphx::half>
{
using type = float;
};
template <>
struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fnuz>
{
using type = float;
};
template <>
struct common_type<migraphx::half, migraphx::half>
{
......
......@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
});
});
return all_same;
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{
if(ins->outputs().size() == 1)
......@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
template <class... Ms>
auto skip_broadcasts_transposes_contiguous(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "transpose"))(ms...);
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
......
......@@ -34,6 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
......@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
......
......@@ -28,25 +28,35 @@
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_DETAIL_DEFINE_TRAIT(trait) \
template <class X> \
struct trait : std::trait<X> \
{ \
};
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
};
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed);
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
template <class T>
using accumulator_type =
std::conditional_t<is_floating_point<T>{},
......
......@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
std::vector<int64_t> perm{1, 0, 2};
args[0] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0]);
if(args.size() >= 6 and not args[5]->is_undefined())
{
args[5] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[5]);
}
if(args.size() >= 7 and not args[6]->is_undefined())
{
args[6] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[6]);
}
}
void lstm_transpose_outputs(onnx_parser::node_info& info,
instruction_ref& hidden_states,
instruction_ref& last_output,
instruction_ref& last_cell_output)
{
std::vector<int64_t> perm_hs{2, 0, 1, 3};
hidden_states =
info.add_instruction(make_op("transpose", {{"permutation", perm_hs}}), hidden_states);
std::vector<int64_t> perm_last{1, 0, 2};
last_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_output);
last_cell_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_cell_output);
}
struct parse_lstm : op_parser<parse_lstm>
{
std::vector<op_desc> operators() const { return {{"LSTM"}}; }
......@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
}
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// append undefined opeator to make 6 arguments
if(args.size() < 8)
{
......@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args.insert(args.end(), 8 - args.size(), ins);
}
if(layout != 0)
{
lstm_transpose_inputs(info, args);
}
// first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size},
......@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
if(layout != 0)
{
lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output);
}
return {hidden_states, last_output, last_cell_output};
}
};
......
......@@ -36,7 +36,7 @@ namespace onnx {
/*
*********************************************************************************
* Reference: see QLinearAdd in *
* Reference: see QLinearAdd, QLinearMul in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
......@@ -49,6 +49,17 @@ namespace onnx {
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
com.microsoft.QLinearMul
Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting
support).
C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
General definition of binary QLinear* ops:
Inputs (7 - 8)
A : T
First operand.
......@@ -88,15 +99,18 @@ namespace onnx {
*/
struct parse_qlinearadd : op_parser<parse_qlinearadd>
struct parse_qlinearbinary : op_parser<parse_qlinearbinary>
{
std::vector<op_desc> operators() const { return {{"QLinearAdd"}}; }
std::vector<op_desc> operators() const
{
return {{"QLinearAdd", "add"}, {"QLinearMul", "mul"}};
}
// basic type checking for QLinearAdd Operator
void check_inputs(const std::vector<instruction_ref>& args) const
// basic type checking for binary QLinear Operator
void check_inputs(const std::vector<instruction_ref>& args, const std::string& op_name) const
{
if(args.size() < 7)
MIGRAPHX_THROW("QLINEARADD: missing inputs");
MIGRAPHX_THROW(op_name + ": missing inputs");
const auto& in_a = args[0];
const auto& in_b = args[3];
......@@ -107,19 +121,19 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
auto type_a = sh_a.type();
auto type_b = sh_b.type();
if(type_a != migraphx::shape::int8_type and type_a != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARADD: unsupported input type");
MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_b != migraphx::shape::int8_type and type_b != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARADD: unsupported input type");
MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_a != type_b)
MIGRAPHX_THROW("QLINEARADD: mismatched input types");
MIGRAPHX_THROW(op_name + ": mismatched input types");
}
instruction_ref parse(const op_desc& /* opd */,
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
check_inputs(args);
check_inputs(args, opd.op_name);
// A
const auto& in_a = args[0];
......@@ -134,8 +148,8 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
const auto& in_zero_pt_b = args[5];
auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info);
// C = A + B
auto out_c = info.add_common_op("add", dquant_a, dquant_b);
// C = op(A, B)
auto out_c = info.add_common_op(opd.op_name, dquant_a, dquant_b);
const auto& in_scale_c = args[6];
......
......@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/float8.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
......@@ -144,6 +144,18 @@ struct npy_format_descriptor<half>
static constexpr auto name() { return _("half"); }
};
template <>
struct npy_format_descriptor<migraphx::fp8::fp8e4m3fnuz>
{
static std::string format()
{
// following: https://docs.python.org/3/library/struct.html#format-characters
// TODO: need to figure out correct encoding
return "z";
}
static constexpr auto name() { return _("fp8e4m3fnuz"); }
};
} // namespace detail
} // namespace pybind11
......
......@@ -941,15 +941,6 @@ struct find_splits
{
auto split = i->inputs()[split_idx];
assert(split->name() == "slice");
// Insert contiguous for reshapes
auto outputs = i->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), i);
m.replace_instruction(output, output->get_operator(), x);
}
m.replace_instruction(i, split->get_operator(), c);
}
......@@ -1181,13 +1172,6 @@ struct find_conv_dot_horiz_fusion
for(auto arg : range(start, last))
{
auto outputs = arg->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), arg);
m.replace_instruction(output, output->get_operator(), x);
}
int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
......@@ -1487,11 +1471,6 @@ struct find_split_reshape
slc_axis_len;
});
// insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard())
{
input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
}
auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
......
......@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
return s;
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
struct match_find_quantizable_ops
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
static bool
is_valid_scale(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
return scale->get_shape().scalar() or scale->get_shape().elements() == lens.at(axis);
}
static bool is_valid_zero_point(instruction_ref zp)
{
if(not zp->can_eval())
return false;
bool all_zeros = false;
zp->eval().visit([&](auto z) {
all_zeros =
std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
});
});
return all_same;
}
return all_zeros;
}
struct match_find_quantizable_ops
{
static auto
scale_broadcast_op(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
if(scale->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}
static auto dequantizelinear_op(const std::string& name, const std::string& scale)
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop)
{
auto qinp = dqins->inputs().front();
auto next_ins = dqins;
while(next_ins != qop)
{
if(next_ins->name() != "dequantizelinear")
{
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp);
}
next_ins = next_ins->outputs().front();
}
return qinp;
}
static auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{
return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0)))));
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())),
match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
}
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")),
match::arg(1)(dequantizelinear_op("x2", "scale2")));
match::arg(0)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale2", "zp2").bind("dq2"))));
}
void apply(module& m, const match::matcher_result& r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
auto q2 = r.instructions["x2"];
auto dq1 = r.instructions["dq1"];
auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type)
if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or
dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type)
return;
double scale;
visit_all(scale1->get_literal(), scale2->get_literal())(
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); });
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
if(not(is_valid_zero_point(zp1) and is_valid_zero_point(zp2)))
return;
// Only support scalar and 1D scales
if(scale1->get_shape().lens().size() != 1 or scale2->get_shape().lens().size() != 1)
return;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
qop_args.at(0) = q1;
qop_args.at(1) = q2;
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop);
instruction_ref dq;
instruction_ref dq_scale;
instruction_ref out_scale;
instruction_ref zero_point;
if(qop->name() == "convolution")
{
auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
auto out_lens = dq->get_shape().lens();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if(not(is_valid_scale(scale1, out_lens, 1) and is_valid_scale(scale2, out_lens, 1)))
return;
auto s1_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale1, out_lens, 1), scale1);
auto s2_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale2, out_lens, 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
else if(qop->name() == "dot")
{
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
auto out_lens = dq->get_shape().lens();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if(not(is_valid_scale(scale1, out_lens, out_lens.size() - 2) and
is_valid_scale(scale2, out_lens, out_lens.size() - 1)))
return;
auto s1_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
auto s2_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens();
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale);
m.replace_instruction(qop, dq);
}
};
......
......@@ -103,8 +103,6 @@ struct find_reshaper
auto input = mr.instructions["x"];
auto dims = ins->get_shape().lens();
if(not input->get_shape().standard())
input = m.insert_instruction(ins, make_op("contiguous"), input);
m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
}
};
......@@ -475,9 +473,8 @@ struct find_resize
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = m.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), mb_rsp);
}
};
......@@ -626,9 +623,8 @@ struct find_transpose_contiguous_reshaper_unary
auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name();
auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// older cont and reshape are removed by deadcode elimination
m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins);
m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins);
}
};
......
......@@ -251,14 +251,21 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
std::cout << std::string(src.content) << std::endl;
}
}
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
auto fname = fs::path{"migraphx-hiprtc-driver"};
#ifdef _WIN32
auto driver = p.parent_path() / "migraphx-hiprtc-driver.exe";
#else
auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver";
fname.replace_extension(".exe");
#endif
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
auto driver = p.parent_path() / fname;
bool found = fs::exists(driver);
if(not found)
{
driver = p.parent_path().parent_path() / "bin" / fname;
found = fs::exists(driver);
}
if(fs::exists(driver))
if(found)
{
value v;
v["srcs"] = to_value(hsrcs);
......
......@@ -146,20 +146,20 @@ __device__ __host__ T to_hip_type(T x)
// Hip doens't support __fp16
inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
#define MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
} // namespace device
} // namespace gpu
......
......@@ -46,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::fp8e4m3fnuz_type:
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment