Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
27598fab
Commit
27598fab
authored
Nov 09, 2023
by
Umang Yadav
Browse files
changes to make it work with hiprtc
parent
dc9c9784
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
90 additions
and
71 deletions
+90
-71
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+84
-64
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+3
-4
No files found.
src/include/migraphx/migraphx_float8.hpp
View file @
27598fab
...
...
@@ -28,32 +28,36 @@
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif
#endif
// __clang__
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
#ifndef __HIPCC_RTC__
#include <hip/hip_runtime.h>
#else
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
#include <migraphx/kernels/hip.hpp>
#else
#include <hip/hip_runtime.h>
#endif
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#define MIGRAPHX_HIP_HOST __host__
#else
#define MIGRAPHX_HIP_HOST_DEVICE
#define MIGRAPHX_HIP_HOST
#endif
#endif
// HIP_PLATFORM_AMD
#define MIGRAPHX_HIP_DEVICE __device__
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif
#endif
// MIGRAPHX_FP8_FNUZ
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#ifndef __HIPCC_RTC__
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
#include <migraphx/kernels/types.hpp>
using
uint8_t
=
migraphx
::
uint8_t
;
using
uint16_t
=
migraphx
::
uint16_t
;
using
uint32_t
=
migraphx
::
uint32_t
;
#else
#include <cmath>
#include <cstdint>
#include <climits>
...
...
@@ -92,6 +96,9 @@ enum class hip_f8_type
fp8
=
1
// s1e4m3
};
template
<
typename
T
>
class
NumericLimits
;
template
<
migraphx_fp8
::
hip_f8_type
T
=
migraphx_fp8
::
hip_f8_type
::
fp8
>
struct
hip_f8
{
...
...
@@ -388,7 +395,7 @@ struct hip_f8
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
bool
operator
==
(
const
hip_f8
&
rhs
)
const
{
if
((
rhs
.
is_zero
()
&&
this
->
is_zero
())
||
(
fabs
(
rhs
-
*
this
)
<
std
::
n
umeric
_l
imits
<
hip_f8
<
T
>>::
epsilon
()))
(
fabs
(
rhs
-
*
this
)
<
migraphx_fp8
::
N
umeric
L
imits
<
hip_f8
<
T
>>::
epsilon
()))
return
true
;
else
if
(
rhs
.
is_nan
()
||
rhs
.
is_inf
()
||
this
->
is_nan
()
||
this
->
is_inf
())
return
false
;
...
...
@@ -411,7 +418,7 @@ struct hip_f8
}
};
#ifndef
_
_HIP
CC_
RTC
__
#ifndef
MIGRAPHX_JIT_USE
_HIPRTC
// Special operator overloading
template
<
migraphx_fp8
::
hip_f8_type
T
>
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
migraphx_fp8
::
hip_f8
<
T
>&
rhs
)
...
...
@@ -463,6 +470,69 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest()
using
fp8e4m3fnuz
=
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
;
template
<
>
class
NumericLimits
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
{
public:
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
epsilon
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
float
(
0.0625
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
quiet_NaN
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
static_cast
<
uint8_t
>
(
MIGRAPHX_FP8_FNUZ
?
0X80
:
0x79
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
max
()
{
return
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
min
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
-
1.0
f
)
*
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
lowest
()
{
return
migraphx_fp8
::
F8_Lowest
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
};
template
<
>
class
NumericLimits
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
{
public:
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
epsilon
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
float
(
0.125
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
quiet_NaN
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
static_cast
<
uint8_t
>
(
MIGRAPHX_FP8_FNUZ
?
0X80
:
0x7d
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
max
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
());
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
min
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
float
(
-
1.0
f
))
*
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
();
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
lowest
()
{
return
migraphx_fp8
::
F8_Lowest
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
();
}
};
/*
// Use h/w intrinsic and optimized version when __gfx940__
template <typename T,
...
...
@@ -511,6 +581,7 @@ inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
*/
}
// namespace migraphx_fp8
// define numeric limits for the new data type
#ifndef MIGRAPHX_JIT_USE_HIPRTC
namespace
std
{
inline
bool
isfinite
(
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
x
)
// NOLINT
{
...
...
@@ -524,66 +595,14 @@ inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> x) //
template
<
>
class
numeric_limits
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
:
public
migraphx_fp8
::
NumericLimits
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
{
public:
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
epsilon
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
float
(
0.0625
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
quiet_NaN
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
static_cast
<
uint8_t
>
(
MIGRAPHX_FP8_FNUZ
?
0X80
:
0x79
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
max
()
{
return
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
min
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
-
1.0
f
)
*
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
lowest
()
{
return
migraphx_fp8
::
F8_Lowest
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
};
template
<
>
class
numeric_limits
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
:
public
migraphx_fp8
::
NumericLimits
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
{
public:
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
epsilon
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
float
(
0.125
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
quiet_NaN
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
static_cast
<
uint8_t
>
(
MIGRAPHX_FP8_FNUZ
?
0X80
:
0x7d
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
max
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
());
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
min
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
float
(
-
1.0
f
))
*
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
();
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
lowest
()
{
return
migraphx_fp8
::
F8_Lowest
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
();
}
};
template
<
class
T
>
...
...
@@ -603,6 +622,7 @@ struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz>
};
}
// namespace std
#endif
// =================================================================================================
#if defined(__clang__)
#pragma clang diagnostic pop
...
...
src/targets/gpu/compile_hip.cpp
View file @
27598fab
...
...
@@ -199,7 +199,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
{
hiprtc_program
prog
(
std
::
move
(
srcs
));
auto
options
=
split_string
(
params
,
' '
);
options
.
push_back
(
"-DMIGRAPHX_USE_HIPRTC=1"
);
options
.
push_back
(
"-DMIGRAPHX_
JIT_
USE_HIPRTC=1"
);
// remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
if
(
enabled
(
MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
{}))
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
View file @
27598fab
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_USE_HIPRTC
#ifndef MIGRAPHX_
JIT_
USE_HIPRTC
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
27598fab
...
...
@@ -24,9 +24,9 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
27598fab
...
...
@@ -23,12 +23,11 @@
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/kernels/hip.hpp>
namespace
migraphx
{
#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC)
#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_
JIT_
USE_HIPRTC)
using
int8_t
=
signed
char
;
using
uint8_t
=
unsigned
char
;
using
int16_t
=
signed
short
;
...
...
@@ -37,7 +36,7 @@ using int32_t = signed int;
using
uint32_t
=
unsigned
int
;
using
int64_t
=
signed
long
long
;
using
uint64_t
=
unsigned
long
long
;
#elif defined(MIGRAPHX_USE_HIPRTC)
#elif defined(MIGRAPHX_
JIT_
USE_HIPRTC)
using
int8_t
=
__hip_int8_t
;
using
uint8_t
=
__hip_uint8_t
;
using
int16_t
=
__hip_int16_t
;
...
...
@@ -55,7 +54,7 @@ using int32_t = std::int32_t;
using
uint32_t
=
std
::
uint32_t
;
using
int64_t
=
std
::
int64_t
;
using
uint64_t
=
std
::
uint64_t
;
#endif // MIGRAPHX_USE_HIPRTC
#endif // MIGRAPHX_
JIT_
USE_HIPRTC
using
index_int
=
uint32_t
;
using
diff_int
=
int32_t
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment