Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d734871c
Commit
d734871c
authored
Nov 07, 2023
by
Umang Yadav
Browse files
Add rocblas's implemenation
parent
9254df13
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
910 additions
and
0 deletions
+910
-0
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+592
-0
src/include/migraphx/migraphx_hip_f8_impl.hpp
src/include/migraphx/migraphx_hip_f8_impl.hpp
+307
-0
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+10
-0
test/CMakeLists.txt
test/CMakeLists.txt
+1
-0
No files found.
src/include/migraphx/migraphx_float8.hpp
0 → 100644
View file @
d734871c
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_FLOAT8_HPP
#define MIGRAPHX_FLOAT8_HPP
#ifdef __HIP_PLATFORM_HCC__
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#else
#define MIGRAPHX_HIP_HOST_DEVICE
#endif
#define MIGRAPHX_HIP_HOST __host__
#define MIGRAPHX_HIP_DEVICE __device__
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#ifndef __HIPCC_RTC__
#include <cmath>
#include <cstdint>
#include <climits>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <sstream>
#include <iostream>
#include <string>
#include <utility>
#include <migraphx/type_traits.hpp>
#else
#include <migraphx/kernels/type_traits.hpp>
#endif
namespace
migraphx_hip_f8_impl
{
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
MIGRAPHX_HIP_HOST_DEVICE
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
);
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
MIGRAPHX_HIP_HOST_DEVICE
T
cast_from_f8
(
uint8_t
x
);
}
// namespace migraphx_hip_f8_impl
#include "migraphx_hip_f8_impl.hpp"
namespace
migraphx_fp8
{
enum
class
migraphx_hip_f8_rounding_mode
{
standard
,
stochastic
};
enum
class
hip_f8_type
{
bf8
=
0
,
// s1e5m2
fp8
=
1
// s1e4m3
};
template
<
migraphx_fp8
::
hip_f8_type
T
=
migraphx_fp8
::
hip_f8_type
::
fp8
>
struct
MIGRAPHX_EXPORT
migraphx_f8
{
uint8_t
data
;
// default constructor
MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8
()
=
default
;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code
template
<
bool
stochastic_rounding
=
false
>
static
MIGRAPHX_HIP_DEVICE
uint8_t
cast_to_f8_from_f32
(
float
v
,
uint32_t
rng
=
0
)
{
uint8_t
i8data
;
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// NOTE: not endian independent
}
val
;
uint32_t
ival
=
0
;
val
.
fval
=
v
;
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
if
constexpr
(
T
==
migraphx_fp8
::
hip_f8_type
::
fp8
)
{
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
}
else
{
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
57344.0
,
-
57344.0
);
}
#endif
if
(
stochastic_rounding
)
{
if
constexpr
(
T
==
migraphx_fp8
::
hip_f8_type
::
fp8
)
{
ival
=
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
}
else
{
ival
=
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
}
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
// little endian
}
else
// RNE CVT
{
if
constexpr
(
T
==
migraphx_fp8
::
hip_f8_type
::
fp8
)
{
ival
=
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
}
else
{
ival
=
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0}
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
}
return
i8data
;
}
}
#endif // __gfx940__
// constructor from float
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: ON-DEVICE... always optimal bias
explicit
MIGRAPHX_HIP_DEVICE
migraphx_f8
(
float
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
{
// runtime branch, use cast_to_f8_from_f32 if want to avoid it
if
(
rm
==
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
stochastic
)
data
=
cast_to_f8_from_f32
<
true
>
(
v
,
rng
);
else
data
=
cast_to_f8_from_f32
<
false
>
(
v
);
}
// Host only implementation using s/w simulation
explicit
MIGRAPHX_HIP_HOST
#else
// both Host and DEVICE for non-gfx940 using s/w simulation
explicit
MIGRAPHX_HIP_HOST_DEVICE
#endif
migraphx_f8
(
float
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
{
if
constexpr
(
T
==
migraphx_fp8
::
hip_f8_type
::
fp8
)
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_hip_f8_impl
::
cast_to_f8
<
3
,
4
,
float
,
true
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_hip_f8_impl
::
cast_to_f8
<
3
,
4
,
float
,
true
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
else
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_hip_f8_impl
::
cast_to_f8
<
2
,
5
,
float
,
true
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_hip_f8_impl
::
cast_to_f8
<
2
,
5
,
float
,
true
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
stochastic
),
rng
);
#endif // rocblas_F8_downcast_clipping}
}
}
// Constructor from half
explicit
MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8
(
migraphx
::
half
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
migraphx_f8
((
float
)
v
,
rm
,
rng
)
{
}
// constructor from int
explicit
MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8
(
int
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
migraphx_f8
((
float
)
v
,
rm
,
rng
)
{
}
// constructor from double
explicit
MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8
(
double
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
migraphx_f8
((
float
)
v
,
rm
,
rng
)
{
}
// convert to float
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// upcast using device specific intrinsic
explicit
inline
MIGRAPHX_HIP_DEVICE
operator
float
()
const
{
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
data
);
// upcast
if
constexpr
(
T
==
migraphx_fp8
::
hip_f8_type
::
fp8
)
{
asm
volatile
(
"v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
}
else
{
asm
volatile
(
"v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
}
return
fval
;
}
explicit
inline
MIGRAPHX_HIP_HOST
operator
float
()
const
#else // non gfx940
explicit
inline
MIGRAPHX_HIP_HOST_DEVICE
operator
float
()
const
#endif
{
if
constexpr
(
T
==
migraphx_fp8
::
hip_f8_type
::
fp8
)
{
return
migraphx_hip_f8_impl
::
cast_from_f8
<
3
,
4
,
float
,
true
/*negative_zero_nan*/
>
(
data
);
}
// else
return
migraphx_hip_f8_impl
::
cast_from_f8
<
2
,
5
,
float
,
true
/*negative_zero_nan*/
>
(
data
);
}
// convert to half
explicit
inline
MIGRAPHX_HIP_HOST_DEVICE
operator
migraphx
::
half
()
const
{
return
migraphx
::
half
(
float
(
*
this
));
// convert to float, then convert to f16
}
// check for zero
inline
MIGRAPHX_HIP_HOST_DEVICE
bool
is_zero
()
const
{
return
data
==
0x00
;
}
// check for nan
inline
MIGRAPHX_HIP_HOST_DEVICE
bool
is_nan
()
const
{
return
data
==
0x80
;
}
// check for inf
inline
MIGRAPHX_HIP_HOST_DEVICE
bool
is_inf
()
const
{
return
data
==
0x80
;
}
// assignment overloading only from the same F8 types
inline
__host__
__device__
migraphx_f8
&
operator
=
(
const
migraphx_f8
&
a
)
{
data
=
a
.
data
;
return
*
this
;
}
};
/*
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const migraphx_f8& f8) { return os << float(f8); }
inline std::ostream& operator<<(std::ostream& os, const migraphx_bf8& bf8)
{
return os << float(bf8);
}
// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline __host__ __device__ float operator+(const float fa, migraphx_f8 b)
{
return (fa + float(b));
}
inline __host__ __device__ float operator+(const float fa, migraphx_bf8 b)
{
return (fa + float(b));
}
inline __host__ __device__ float operator+(migraphx_f8 a, const float fb)
{
return (float(a) + fb);
}
inline __host__ __device__ float operator+(migraphx_bf8 a, const float fb)
{
return (float(a) + fb);
}
inline __host__ __device__ float operator+(migraphx_f8 a, migraphx_bf8 b)
{
return (float(a) + float(b));
}
inline __host__ __device__ float operator+(migraphx_bf8 a, migraphx_f8 b)
{
return (float(a) + float(b));
}
inline __host__ __device__ migraphx_f8 operator+(migraphx_f8 a, migraphx_f8 b)
{
return migraphx_f8(float(a) + float(b));
}
inline __host__ __device__ migraphx_bf8 operator+(migraphx_bf8 a, migraphx_bf8 b)
{
return migraphx_bf8(float(a) + float(b));
}
inline __host__ __device__ migraphx_f8& operator+=(migraphx_f8& a, migraphx_f8 b)
{
return a = migraphx_f8(float(a) + float(b));
}
inline __host__ __device__ migraphx_bf8& operator+=(migraphx_bf8& a, migraphx_bf8 b)
{
return a = migraphx_bf8(float(a) + float(b));
}
// overloading multiplication, always returns float,
inline __host__ __device__ float operator*(migraphx_f8 a, migraphx_f8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(float a, migraphx_f8 b) { return (a * float(b)); }
inline __host__ __device__ float operator*(migraphx_f8 a, float b) { return (float(a) * b); }
inline __host__ __device__ float operator*(int32_t a, migraphx_f8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(double a, migraphx_f8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(migraphx_bf8 a, migraphx_bf8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(float a, migraphx_bf8 b) { return (a * float(b)); }
inline __host__ __device__ float operator*(migraphx_bf8 a, float b) { return (float(a) * b); }
inline __host__ __device__ float operator*(int32_t a, migraphx_bf8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(double a, migraphx_bf8 b)
{
return ((float)a * float(b));
}
// overloading for mixed f8 and bf8 types
inline __host__ __device__ float operator*(migraphx_f8 a, migraphx_bf8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(migraphx_bf8 a, migraphx_f8 b)
{
return float(a) * float(b);
}
// all - operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline __host__ __device__ float operator-(const float fa, migraphx_f8 b)
{
return (fa - float(b));
}
inline __host__ __device__ float operator-(const float fa, migraphx_bf8 b)
{
return (fa - float(b));
}
inline __host__ __device__ float operator-(migraphx_f8 a, const float fb)
{
return (float(a) - fb);
}
inline __host__ __device__ float operator-(migraphx_bf8 a, const float fb)
{
return (float(a) - fb);
}
inline __host__ __device__ float operator-(migraphx_f8 a, migraphx_bf8 b)
{
return (float(a) - float(b));
}
inline __host__ __device__ float operator-(migraphx_bf8 a, migraphx_f8 b)
{
return (float(a) - float(b));
}
inline __host__ __device__ migraphx_f8 operator-(migraphx_f8 a, migraphx_f8 b)
{
return migraphx_f8(float(a) - float(b));
}
inline __host__ __device__ migraphx_bf8 operator-(migraphx_bf8 a, migraphx_bf8 b)
{
return migraphx_bf8(float(a) - float(b));
}
inline __host__ __device__ migraphx_f8& operator-=(migraphx_f8& a, migraphx_f8 b)
{
return a = migraphx_f8(float(a) - float(b));
}
inline __host__ __device__ migraphx_bf8& operator-=(migraphx_bf8& a, migraphx_bf8 b)
{
return a = migraphx_bf8(float(a) - float(b));
}
// overloading division, always returns float,
inline __host__ __device__ float operator/(migraphx_f8 a, migraphx_f8 b)
{
return float(a) / float(b);
}
inline __host__ __device__ float operator/(float a, migraphx_f8 b) { return (a / float(b)); }
inline __host__ __device__ float operator/(migraphx_f8 a, float b) { return (float(a) / b); }
inline __host__ __device__ float operator/(int32_t a, migraphx_f8 b)
{
return ((float)a / float(b));
}
inline __host__ __device__ float operator/(double a, migraphx_f8 b)
{
return ((float)a / float(b));
}
inline __host__ __device__ float operator/(migraphx_bf8 a, migraphx_bf8 b)
{
return float(a) / float(b);
}
inline __host__ __device__ float operator/(float a, migraphx_bf8 b) { return (a / float(b)); }
inline __host__ __device__ float operator/(migraphx_bf8 a, float b) { return (float(a) / b); }
inline __host__ __device__ float operator/(int32_t a, migraphx_bf8 b)
{
return ((float)a / float(b));
}
inline __host__ __device__ float operator/(double a, migraphx_bf8 b)
{
return ((float)a / float(b));
}
// overloading for mixed f8 and bf8 types
inline __host__ __device__ float operator/(migraphx_f8 a, migraphx_bf8 b)
{
return float(a) / float(b);
}
inline __host__ __device__ float operator/(migraphx_bf8 a, migraphx_f8 b)
{
return float(a) / float(b);
}
// overloading for compare
inline __host__ __device__ bool operator==(migraphx_f8 a, migraphx_f8 b)
{
return (a.data == b.data);
}
inline __host__ __device__ bool operator==(migraphx_bf8 a, migraphx_bf8 b)
{
return (a.data == b.data);
}
inline __host__ __device__ bool operator!=(migraphx_f8 a, migraphx_f8 b)
{
return (a.data != b.data);
}
inline __host__ __device__ bool operator!=(migraphx_bf8 a, migraphx_bf8 b)
{
return (a.data != b.data);
}
// ================ Explicit downcasting to support different rounding (RNE, SR)
// =============== NOTE: we going to remove all assignment operator overloading from other
// types and enforce this explicit_downcast function to make any roudning behavior default
// We have to explicitly call this function with SR flag
template <typename T,
typename Ta,
bool stochastic_rounding,
typename std::enable_if<migraphx::is_same<T, Ta>{}, int>::type = 0>
inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng = 0)
{
// same type, no conversion
return a;
}
// Use h/w intrinsic and optimized version when __gfx940__
template <typename T,
typename Ta,
bool stochastic_rounding,
typename std::enable_if<(!(migraphx::is_same<T, Ta>{}) &&
(migraphx::is_same<T, migraphx_f8>{} ||
migraphx::is_same<T, migraphx_bf8>{})),
int>::type = 0>
inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: we are directly calling cast_to_f8_from_f32 instead of constructor to optimize
// away one runtime branch
T val;
if(migraphx::is_same<T, migraphx_f8>::value)
val.data = migraphx_f8::cast_to_f8_from_f32<stochastic_rounding>(float(a), rng);
else
val.data = migraphx_bf8::cast_to_bf8_from_f32<stochastic_rounding>(float(a), rng);
return val;
#else // non gfx940
return T(float(a),
stochastic_rounding ? migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic
: migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
rng);
#endif // __gfx940__
}
// NOTE NOTE: The above code is good if we don't consider HIP-GEMM code and only consider
// the quantization However, if we need HIP-GEMM for fall-back, we would need explicit_cast
// handles Tacc=f32 to To=f16/bf16 conversion
template <typename T,
typename Ta,
bool stochastic_rounding,
typename std::enable_if<(!(migraphx::is_same<T, Ta>{}) &&
!(migraphx::is_same<T, migraphx_f8>{} ||
migraphx::is_same<T, migraphx_bf8>{})),
int>::type = 0>
inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
{
// the return type is not a F8 types, no SR for those types
// not sure if we have direct conversion, so converting to float first
// no effect if the input type is float
return T(float(a));
}
*/
}
// namespace migraphx_fp8
/*
namespace std {
inline migraphx_f8 sin(migraphx_f8 a) { return migraphx_f8(sinf(float(a))); }
inline migraphx_f8 cos(migraphx_f8 a) { return migraphx_f8(cosf(float(a))); }
inline migraphx_bf8 sin(migraphx_bf8 a) { return migraphx_bf8(sinf(float(a))); }
inline migraphx_bf8 cos(migraphx_bf8 a) { return migraphx_bf8(cosf(float(a))); }
__device__ __host__ constexpr migraphx_f8 real(const migraphx_f8& a) { return a; }
__device__ __host__ constexpr migraphx_bf8 real(const migraphx_bf8& a) { return a; }
} // namespace std
*/
// =================================================================================================
#endif // MIGRAPHX_FLOAT8_HPP
src/include/migraphx/migraphx_hip_f8_impl.hpp
0 → 100644
View file @
d734871c
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_HIP_FP8_IMPL_HPP
#define MIGRAPHX_HIP_FP8_IMPL_HPP
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
namespace
migraphx_hip_f8_impl
{
// #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); }
// #else
// __host__ inline int clz(uint32_t x) { return __builtin_clz(x); }
// #endif
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
MIGRAPHX_HIP_HOST_DEVICE
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
{
constexpr
bool
is_half
=
migraphx
::
is_same
<
T
,
migraphx
::
half
>
{};
constexpr
bool
is_float
=
migraphx
::
is_same
<
T
,
float
>
{};
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
is_half
||
is_float
,
"Only half and float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
uint32_t
x
;
if
(
sizeof
(
T
)
==
4
)
x
=
reinterpret_cast
<
uint32_t
&>
(
_x
);
else
x
=
reinterpret_cast
<
uint16_t
&>
(
_x
);
uint32_t
head
,
mantissa
;
int
exponent
,
bias
;
uint32_t
sign
;
if
(
sizeof
(
T
)
==
4
)
{
head
=
x
&
0xFF800000
;
mantissa
=
x
&
0x7FFFFF
;
exponent
=
(
head
>>
23
)
&
0xFF
;
sign
=
head
>>
31
;
bias
=
127
;
}
else
{
head
=
x
&
0xFC00
;
mantissa
=
x
&
0x3FF
;
exponent
=
(
head
>>
10
)
&
0x1F
;
sign
=
head
>>
15
;
bias
=
15
;
}
uint32_t
signed_inf
=
(
sign
<<
7
)
+
(((
1
<<
we
)
-
1
)
<<
wm
);
// Deal with inf and NaNs
if
(
negative_zero_nan
)
{
if
(
sizeof
(
T
)
==
4
)
{
if
((
x
&
0x7F800000
)
==
0x7F800000
)
return
0x80
;
}
else
{
// if(__hisinf(x) || __hisnan(x))
if
((
x
&
0x7C00
)
==
0x7C00
)
return
0x80
;
}
}
else
{
if
(
sizeof
(
T
)
==
4
)
{
if
((
x
&
0x7F800000
)
==
0x7F800000
)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
else
{
if
((
x
&
0x7C00
)
==
0x7C00
)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
if
(
x
==
0
)
return
0
;
// First need to check if it is normal or denorm as there is a difference of implict 1
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
// exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
const
int
f8_bias
=
(
1
<<
(
we
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
f8_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In
this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
f8_denormal_act_exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa
+=
(
1
<<
mfmt
);
// Add the implicit 1 into mantissa
}
bool
midpoint
=
(
mantissa
&
((
1
<<
(
mfmt
-
wm
+
exponent_diff
))
-
1
))
==
(
1
<<
(
mfmt
-
wm
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint.
*/
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
uint32_t
drop_mask
=
(
1
<<
(
mfmt
-
wm
))
-
1
;
bool
odd
=
mantissa
&
(
1
<<
(
mfmt
-
wm
));
// if the least significant bit that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
f8_exponent
==
0
)
{
if
((
1
<<
mfmt
)
&
mantissa
)
{
f8_exponent
=
1
;
// denormal overflow to become normal, promote exponent
}
}
else
{
if
((
1
<<
(
mfmt
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
f8_exponent
++
;
}
}
mantissa
>>=
(
mfmt
-
wm
);
// above range: quantize to maximum possible float of the same sign
const
int
max_exp
=
(
1
<<
we
)
-
(
negative_zero_nan
?
1
:
2
);
if
(
f8_exponent
>
max_exp
)
{
if
(
clip
)
{
mantissa
=
(
1
<<
wm
)
-
1
;
f8_exponent
=
max_exp
;
}
else
{
return
signed_inf
;
}
}
if
(
f8_exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
7
);
mantissa
&=
(
1
<<
wm
)
-
1
;
return
(
sign
<<
7
)
|
(
f8_exponent
<<
wm
)
|
mantissa
;
}
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
MIGRAPHX_HIP_HOST_DEVICE
T
cast_from_f8
(
uint8_t
x
)
{
constexpr
bool
is_half
=
migraphx
::
is_same
<
T
,
migraphx
::
half
>
{};
constexpr
bool
is_float
=
migraphx
::
is_same
<
T
,
float
>
{};
static_assert
(
is_half
||
is_float
,
"only half and float are supported"
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
if
(
is_half
)
{
const
uint16_t
ihInf
=
0x7C00
;
const
uint16_t
ihNegInf
=
0xFC00
;
const
uint16_t
ihNaN
=
0x7C01
;
const
uint16_t
ihNeg0
=
0x8000
;
fInf
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihInf
);
fNegInf
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNegInf
);
fNaN
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNaN
);
fNeg0
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNeg0
);
}
else
if
(
is_float
)
{
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
reinterpret_cast
<
const
float
&>
(
ifInf
);
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
}
if
(
x
==
0
)
return
0
;
uint32_t
sign
=
x
>>
7
;
uint32_t
mantissa
=
x
&
((
1
<<
wm
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
wm
;
if
(
negative_zero_nan
)
{
if
(
x
==
0x80
)
return
fNaN
;
}
else
{
if
(
x
==
0x80
)
return
fNeg0
;
if
(
exponent
==
((
1
<<
we
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
}
typename
migraphx
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
if
(
we
==
5
&&
is_half
&&
!
negative_zero_nan
)
{
retval
=
x
<<
8
;
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
// subnormal input
if
(
exponent
==
0
)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
(
32
-
wm
);
mantissa
<<=
sh
;
exponent
+=
1
-
sh
;
mantissa
&=
((
1
<<
wm
)
-
1
);
}
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
wmo
-
wm
;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
{
mantissa
|=
1
<<
wmo
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
if
(
sizeof
(
T
)
==
2
)
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
else
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
}
// namespace migraphx_hip_f8_impl
#pragma clang diagnostic pop
#endif // MIGRAPHX_HIP_FP8_IMPL_HPP
src/include/migraphx/type_traits.hpp
View file @
d734871c
...
...
@@ -49,6 +49,16 @@ MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_arithmetic
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_signed
);
template
<
class
T
,
class
U
>
struct
is_same
:
std
::
is_same
<
T
,
U
>
{
};
template
<
bool
B
,
class
T
,
class
U
>
struct
conditional
:
std
::
conditional
<
B
,
T
,
U
>
{
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
...
...
test/CMakeLists.txt
View file @
d734871c
...
...
@@ -150,6 +150,7 @@ function(test_headers PREFIX)
list
(
REMOVE_ITEM HEADERS
${
CMAKE_SOURCE_DIR
}
/src/targets/gpu/include/migraphx/gpu/ck.hpp
)
endif
()
list
(
REMOVE_ITEM HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/migraphx_hip_f8_impl.hpp
)
foreach
(
HEADER
${
HEADERS
}
)
file
(
RELATIVE_PATH HEADER_REL
${
CMAKE_SOURCE_DIR
}
${
HEADER
}
)
string
(
MAKE_C_IDENTIFIER
${
HEADER_REL
}
TEST_NAME
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment