Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c2b41d9d
Commit
c2b41d9d
authored
Oct 31, 2023
by
Umang Yadav
Browse files
Add type in API, add MIGRAPHX_HIP_HOST_DEVICE
parent
c75fb295
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
480 additions
and
18 deletions
+480
-18
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+1
-0
src/include/migraphx/fp8e4m3fnuz.hpp
src/include/migraphx/fp8e4m3fnuz.hpp
+445
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+2
-0
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+15
-5
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
+12
-12
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-0
tools/api/migraphx.h
tools/api/migraphx.h
+4
-1
No files found.
src/api/include/migraphx/migraphx.h
View file @
c2b41d9d
...
@@ -37,6 +37,7 @@
...
@@ -37,6 +37,7 @@
m(half_type, half) \
m(half_type, half) \
m(float_type, float) \
m(float_type, float) \
m(double_type, double) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
m(uint16_type, uint16_t) \
...
...
src/include/migraphx/fp8e4m3fnuz.hpp
0 → 100644
View file @
c2b41d9d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_FP8E4M3FNUZ_HPP
#define MIGRAPHX_GUARD_RTGLIB_FP8E4M3FNUZ_HPP
/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including
/// conversions to standard C types and basic arithmetic operations. Note that
/// arithmetic operations are implemented by converting to floating point and
/// performing the operation in float32.
///
/// Binary configuration remains the same as Float8_e4m3fn:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
///
/// The key differences versus Float8_e4m3fn are:
/// bias = 8
/// no infinities or negative zero
/// NaN only when sign bit is 1, rest all 0s
///
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
/// the existing Float8_e4m3fn implementation.
#include <type_traits>
#include <cmath>
#include <cstdint>
#include <climits>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <sstream>
#include <iostream>
#include <migraphx/half.hpp>
#include <string>
#include <utility>
// MIGRAPHX_HIP_HOST_DEVICE is the function qualifier applied to the conversion
// helpers and fp8e4m3fnuz members declared in this header. It expands to
// `__host__ __device__` when __HIP_PLATFORM_HCC__ is defined and to nothing
// otherwise, so the same declarations can be used with or without HIP.
#if defined __HIP_PLATFORM_HCC__
// MIOpen by default does not have device code in the regular compilation paths,
// therefore, when this file is used from the host side, compilation takes much
// longer. By guarding the __device__ directive we can control that such compilation
// only happens for kernels which include this file.
// NOTE(review): the paragraph above refers to MIOpen; presumably it was carried
// over together with this guard -- confirm the wording applies to this project.
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#else
#define MIGRAPHX_HIP_HOST_DEVICE
#endif
// Save the current clang diagnostic state and silence the warnings listed below
// for the remainder of this header; the matching `#pragma clang diagnostic pop`
// sits just before the closing include guard.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wimplicit-int-float-conversion"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
namespace
migraphx
{
namespace
detail
{
/// Reinterpret the 32-bit pattern `w` as a single-precision float.
///
/// The conversion is a pure bit copy through an anonymous union: the integer
/// member is written and the float member is read back. No numeric conversion
/// takes place.
inline float MIGRAPHX_HIP_HOST_DEVICE fp32_from_bits(uint32_t w)
{
    union
    {
        uint32_t as_bits;
        float as_value;
    } pun;
    pun.as_bits = w;
    return pun.as_value;
}
/// Return the raw 32-bit pattern that encodes the single-precision float `f`.
///
/// Inverse of fp32_from_bits: the float member of an anonymous union is
/// written and the integer member is read back, so the bits are copied
/// unchanged.
inline uint32_t MIGRAPHX_HIP_HOST_DEVICE fp32_to_bits(float f)
{
    union
    {
        float as_value;
        uint32_t as_bits;
    } pun;
    pun.as_value = f;
    return pun.as_bits;
}
/*
 * Decode an fp8 E4M3FNUZ bit pattern into the float value it represents.
 *
 * @param input raw 8-bit encoding (s eeee mmm)
 * @return the decoded value; the encoding 0x80 decodes to a quiet NaN
 *
 * @note The decode is a single lookup in a 256-entry table indexed by the raw
 * byte, so no floating-point arithmetic is performed.
 */
inline float MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz_to_fp32_value(uint8_t input)
{
    // Entries 0x00-0x7F hold the non-negative values in increasing order.
    // Entry 0x80 is NaN and entries 0x81-0xFF mirror 0x01-0x7F with the sign flipped.
    // clang-format off
    constexpr float decode_table[256] = {
        // 0x00 - 0x0F
        0.0f, 0.0009765625f, 0.001953125f, 0.0029296875f,
        0.00390625f, 0.0048828125f, 0.005859375f, 0.0068359375f,
        0.0078125f, 0.0087890625f, 0.009765625f, 0.0107421875f,
        0.01171875f, 0.0126953125f, 0.013671875f, 0.0146484375f,
        // 0x10 - 0x1F
        0.015625f, 0.017578125f, 0.01953125f, 0.021484375f,
        0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f,
        0.03125f, 0.03515625f, 0.0390625f, 0.04296875f,
        0.046875f, 0.05078125f, 0.0546875f, 0.05859375f,
        // 0x20 - 0x2F
        0.0625f, 0.0703125f, 0.078125f, 0.0859375f,
        0.09375f, 0.1015625f, 0.109375f, 0.1171875f,
        0.125f, 0.140625f, 0.15625f, 0.171875f,
        0.1875f, 0.203125f, 0.21875f, 0.234375f,
        // 0x30 - 0x3F
        0.25f, 0.28125f, 0.3125f, 0.34375f,
        0.375f, 0.40625f, 0.4375f, 0.46875f,
        0.5f, 0.5625f, 0.625f, 0.6875f,
        0.75f, 0.8125f, 0.875f, 0.9375f,
        // 0x40 - 0x4F
        1.0f, 1.125f, 1.25f, 1.375f,
        1.5f, 1.625f, 1.75f, 1.875f,
        2.0f, 2.25f, 2.5f, 2.75f,
        3.0f, 3.25f, 3.5f, 3.75f,
        // 0x50 - 0x5F
        4.0f, 4.5f, 5.0f, 5.5f,
        6.0f, 6.5f, 7.0f, 7.5f,
        8.0f, 9.0f, 10.0f, 11.0f,
        12.0f, 13.0f, 14.0f, 15.0f,
        // 0x60 - 0x6F
        16.0f, 18.0f, 20.0f, 22.0f,
        24.0f, 26.0f, 28.0f, 30.0f,
        32.0f, 36.0f, 40.0f, 44.0f,
        48.0f, 52.0f, 56.0f, 60.0f,
        // 0x70 - 0x7F
        64.0f, 72.0f, 80.0f, 88.0f,
        96.0f, 104.0f, 112.0f, 120.0f,
        128.0f, 144.0f, 160.0f, 176.0f,
        192.0f, 208.0f, 224.0f, 240.0f,
        // 0x80 - 0x8F (0x80 is the only NaN encoding)
        std::numeric_limits<float>::quiet_NaN(), -0.0009765625f, -0.001953125f, -0.0029296875f,
        -0.00390625f, -0.0048828125f, -0.005859375f, -0.0068359375f,
        -0.0078125f, -0.0087890625f, -0.009765625f, -0.0107421875f,
        -0.01171875f, -0.0126953125f, -0.013671875f, -0.0146484375f,
        // 0x90 - 0x9F
        -0.015625f, -0.017578125f, -0.01953125f, -0.021484375f,
        -0.0234375f, -0.025390625f, -0.02734375f, -0.029296875f,
        -0.03125f, -0.03515625f, -0.0390625f, -0.04296875f,
        -0.046875f, -0.05078125f, -0.0546875f, -0.05859375f,
        // 0xA0 - 0xAF
        -0.0625f, -0.0703125f, -0.078125f, -0.0859375f,
        -0.09375f, -0.1015625f, -0.109375f, -0.1171875f,
        -0.125f, -0.140625f, -0.15625f, -0.171875f,
        -0.1875f, -0.203125f, -0.21875f, -0.234375f,
        // 0xB0 - 0xBF
        -0.25f, -0.28125f, -0.3125f, -0.34375f,
        -0.375f, -0.40625f, -0.4375f, -0.46875f,
        -0.5f, -0.5625f, -0.625f, -0.6875f,
        -0.75f, -0.8125f, -0.875f, -0.9375f,
        // 0xC0 - 0xCF
        -1.0f, -1.125f, -1.25f, -1.375f,
        -1.5f, -1.625f, -1.75f, -1.875f,
        -2.0f, -2.25f, -2.5f, -2.75f,
        -3.0f, -3.25f, -3.5f, -3.75f,
        // 0xD0 - 0xDF
        -4.0f, -4.5f, -5.0f, -5.5f,
        -6.0f, -6.5f, -7.0f, -7.5f,
        -8.0f, -9.0f, -10.0f, -11.0f,
        -12.0f, -13.0f, -14.0f, -15.0f,
        // 0xE0 - 0xEF
        -16.0f, -18.0f, -20.0f, -22.0f,
        -24.0f, -26.0f, -28.0f, -30.0f,
        -32.0f, -36.0f, -40.0f, -44.0f,
        -48.0f, -52.0f, -56.0f, -60.0f,
        // 0xF0 - 0xFF
        -64.0f, -72.0f, -80.0f, -88.0f,
        -96.0f, -104.0f, -112.0f, -120.0f,
        -128.0f, -144.0f, -160.0f, -176.0f,
        -192.0f, -208.0f, -224.0f, -240.0f,
    };
    // clang-format on
    return decode_table[input];
}
/*
 * Encode an IEEE single-precision value as an fp8 E4M3FNUZ bit pattern.
 *
 * @param f value to encode
 * @return the 8-bit encoding (s eeee mmm)
 *
 * Magnitudes of 256.0f or more (which includes infinities and NaNs), and
 * magnitudes that round up to 256, produce the single NaN encoding 0x80.
 * Inputs that round to zero produce 0x00 regardless of sign, because the
 * format has no negative zero. In the normal range ties are rounded to the
 * even mantissa.
 */
inline uint8_t MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz_from_fp32_value(float f)
{
    // Bit pattern of 256.0f (biased fp32 exponent 0x87 = 0b10000111): the first
    // magnitude whose fp8 encoding would carry into the sign bit.
    constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23;
    // Bit pattern of 2^13, biased exponent (127 - 8) + (23 - 3) + 1 = 0x8C.
    // Adding this value to a small float leaves the rounded fp8 code in the
    // low bits of the sum.
    constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23;
    // Bit pattern of 2^-6: magnitudes below it take the float-addition path.
    constexpr uint32_t denorm_limit = UINT32_C(121) << 23;

    const uint32_t raw = fp32_to_bits(f);
    // Sign lives in bit 31; the magnitude occupies bits 0-30.
    const uint32_t sign = raw & UINT32_C(0x80000000);
    uint32_t magnitude  = raw ^ sign;

    if(magnitude >= fnuz_max)
    {
        // NaN -- sign bit set to 1, rest 0s
        return 0x80;
    }

    uint32_t encoded = 0u;
    if(magnitude >= denorm_limit)
    {
        // Remember whether the mantissa that will be kept is odd, so an exact
        // tie is pushed towards the even neighbour.
        const uint8_t mant_odd = (magnitude >> 20) & 1;
        // Rebias the exponent from 127 to 8 and add just under half of the
        // discarded mantissa range ...
        magnitude += (static_cast<uint32_t>(8 - 127) << 23) + 0x7FFFF;
        // ... plus one more when the kept mantissa is odd.
        magnitude += mant_odd;
        // The top bits now hold the exponent and 3-bit mantissa.
        encoded = static_cast<uint8_t>(magnitude >> 20);
    }
    else
    {
        // Let the float addition do the rounding; subtracting the mask
        // afterwards leaves only the fp8 code.
        magnitude = fp32_to_bits(fp32_from_bits(magnitude) + fp32_from_bits(denorm_mask));
        encoded   = static_cast<uint8_t>(magnitude - denorm_mask);
        if(encoded == 0)
        {
            // fnuz types don't have negative zero.
            return 0;
        }
    }

    // Move the fp32 sign (bit 31) into the fp8 sign position (bit 7).
    encoded |= sign >> 24;
    return encoded;
}
}
// namespace detail
/// 8-bit floating-point value stored in the E4M3FNUZ encoding (s eeee mmm).
///
/// The object is a single byte holding the raw encoding. Conversions to and
/// from float go through detail::fp8e4m3fnuz_to_fp32_value and
/// detail::fp8e4m3fnuz_from_fp32_value; arithmetic is carried out in float and
/// the result is re-encoded.
struct alignas(1) fp8e4m3fnuz
{
    /// Raw encoded bits.
    uint8_t x;

    /// Tag type selecting the constructor that takes raw bits.
    struct from_bits_t
    {
    };

    /// Returns the tag used to construct a value directly from raw bits.
    MIGRAPHX_HIP_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }

    /// Constructs the value with encoding 0x00 (zero).
    MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz() : x(0) {}

    /// Constructs a value whose encoding is exactly `bits`; no conversion is done.
    MIGRAPHX_HIP_HOST_DEVICE constexpr fp8e4m3fnuz(uint8_t bits, from_bits_t) : x(bits) {}

    /// Encodes the float `value`.
    explicit MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz(float value)
        : x(detail::fp8e4m3fnuz_from_fp32_value(value))
    {
    }

    /// Encodes the half-precision `value` by first widening it to float.
    explicit MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz(migraphx::half value)
        : x(detail::fp8e4m3fnuz_from_fp32_value(static_cast<float>(value)))
    {
    }

    /// Decodes the stored bits to float.
    MIGRAPHX_HIP_HOST_DEVICE operator float() const
    {
        return detail::fp8e4m3fnuz_to_fp32_value(x);
    }

    /// Replaces the stored value with the encoding of `rhs`.
    MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz& operator=(float rhs)
    {
        x = detail::fp8e4m3fnuz_from_fp32_value(rhs);
        return *this;
    }

    /// Replaces the stored value with the encoding of `rhs`, widened to float.
    MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz& operator=(migraphx::half rhs)
    {
        x = detail::fp8e4m3fnuz_from_fp32_value(static_cast<float>(rhs));
        return *this;
    }

    /// True when the stored bits are the NaN encoding (sign bit set, rest zero).
    MIGRAPHX_HIP_HOST_DEVICE bool isnan() const { return x == 0b10000000; }

    // The compound assignments decode *this (not the raw byte `x`) to float,
    // apply the operation with *this as the left operand, and re-encode.

    /// *this = *this + rhs, computed in float.
    MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz& operator+=(float rhs)
    {
        x = detail::fp8e4m3fnuz_from_fp32_value(static_cast<float>(*this) + rhs);
        return *this;
    }

    /// *this = *this - rhs, computed in float.
    MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz& operator-=(float rhs)
    {
        x = detail::fp8e4m3fnuz_from_fp32_value(static_cast<float>(*this) - rhs);
        return *this;
    }

    /// *this = *this * rhs, computed in float.
    MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz& operator*=(float rhs)
    {
        x = detail::fp8e4m3fnuz_from_fp32_value(static_cast<float>(*this) * rhs);
        return *this;
    }

    /// *this = *this / rhs, computed in float.
    MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz& operator/=(float rhs)
    {
        x = detail::fp8e4m3fnuz_from_fp32_value(static_cast<float>(*this) / rhs);
        return *this;
    }
};
/// Streams `value` by converting it to float and writing that float to `out`.
/// Returns `out` so insertions can be chained.
inline std::ostream& operator<<(std::ostream& out, const fp8e4m3fnuz& value)
{
    const float as_float = static_cast<float>(value);
    return out << as_float;
}
}
// namespace migraphx
namespace
std
{
template
<
>
class
numeric_limits
<
migraphx
::
fp8e4m3fnuz
>
{
public:
static
constexpr
bool
is_specialized
=
true
;
static
constexpr
bool
is_signed
=
true
;
static
constexpr
bool
is_integer
=
false
;
static
constexpr
bool
is_exact
=
false
;
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
bool
has_quiet_NaN
=
true
;
static
constexpr
bool
has_signaling_NaN
=
false
;
static
constexpr
auto
has_denorm
=
true
;
static
constexpr
auto
has_denorm_loss
=
true
;
static
constexpr
auto
round_style
=
numeric_limits
<
float
>::
round_style
;
static
constexpr
bool
is_iec559
=
false
;
static
constexpr
bool
is_bounded
=
true
;
static
constexpr
bool
is_modulo
=
false
;
static
constexpr
int
digits
=
4
;
static
constexpr
int
digits10
=
0
;
static
constexpr
int
max_digits10
=
3
;
static
constexpr
int
radix
=
2
;
static
constexpr
int
min_exponent
=
-
5
;
static
constexpr
int
min_exponent10
=
-
1
;
static
constexpr
int
max_exponent
=
8
;
static
constexpr
int
max_exponent10
=
2
;
static
constexpr
auto
traps
=
numeric_limits
<
float
>::
traps
;
static
constexpr
auto
tinyness_before
=
false
;
static
constexpr
migraphx
::
fp8e4m3fnuz
min
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x08
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
lowest
()
{
return
migraphx
::
fp8e4m3fnuz
(
0xFF
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
max
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x7F
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
epsilon
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x28
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
round_error
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x38
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
infinity
()
{
// NaN (no infinities)
return
migraphx
::
fp8e4m3fnuz
(
0x80
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
quiet_NaN
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x80
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
denorm_min
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x01
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
};
/// Mixing fp8e4m3fnuz with any other type yields whatever that type has in
/// common with float.
template <class Other>
struct common_type<migraphx::fp8e4m3fnuz, Other> : std::common_type<float, Other> // NOLINT
{
};

/// Same rule with the operands in the opposite order.
template <class Other>
struct common_type<Other, migraphx::fp8e4m3fnuz> : std::common_type<float, Other> // NOLINT
{
};
// Explicit specializations for the pairings below; each fixes the common type
// to float. The fp8/fp8 case would otherwise match both partial
// specializations above. NOTE(review): the half pairings presumably resolve a
// similar overlap with specializations declared in half.hpp -- confirm.

/// fp8e4m3fnuz with fp8e4m3fnuz.
template <>
struct common_type<migraphx::fp8e4m3fnuz, migraphx::fp8e4m3fnuz>
{
    using type = float;
};

/// fp8e4m3fnuz with half.
template <>
struct common_type<migraphx::fp8e4m3fnuz, migraphx::half>
{
    using type = float;
};

/// half with fp8e4m3fnuz.
template <>
struct common_type<migraphx::half, migraphx::fp8e4m3fnuz>
{
    using type = float;
};
}
// namespace std
#pragma clang diagnostic pop
#endif
src/include/migraphx/shape.hpp
View file @
c2b41d9d
...
@@ -34,6 +34,7 @@
...
@@ -34,6 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
@@ -53,6 +54,7 @@ struct MIGRAPHX_EXPORT shape
...
@@ -53,6 +54,7 @@ struct MIGRAPHX_EXPORT shape
m(half_type, half) \
m(half_type, half) \
m(float_type, float) \
m(float_type, float) \
m(double_type, double) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
m(uint16_type, uint16_t) \
...
...
src/include/migraphx/type_traits.hpp
View file @
c2b41d9d
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <type_traits>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
@@ -32,21 +33,30 @@
...
@@ -32,21 +33,30 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#define MIGRAPHX_DETAIL_
EXTEND
_TRAIT
_FOR
(trait
, T
) \
#define MIGRAPHX_DETAIL_
DEFINE
_TRAIT(trait) \
template <class X> \
template <class X> \
struct trait : std::trait<X> \
struct trait : std::trait<X> \
{ \
{ \
}; \
};
\
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <> \
template <> \
struct trait<T> : std::true_type \
struct trait<T> : std::true_type \
{ \
{ \
};
};
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_floating_point
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_arithmetic
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_signed
);
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
fp8e4m3fnuz
)
template
<
class
T
>
template
<
class
T
>
using
accumulator_type
=
using
accumulator_type
=
std
::
conditional_t
<
is_floating_point
<
T
>
{},
std
::
conditional_t
<
is_floating_point
<
T
>
{},
...
...
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
View file @
c2b41d9d
...
@@ -146,7 +146,7 @@ __device__ __host__ T to_hip_type(T x)
...
@@ -146,7 +146,7 @@ __device__ __host__ T to_hip_type(T x)
// HIP doesn't support __fp16
// HIP doesn't support __fp16
inline
__device__
__host__
float
to_hip_type
(
gpu_half
x
)
{
return
x
;
}
inline
__device__
__host__
float
to_hip_type
(
gpu_half
x
)
{
return
x
;
}
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
#define MIGRAPHX_DETAIL_
DEVICE_
EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
template <class X> \
struct trait : std::trait<X> \
struct trait : std::trait<X> \
{ \
{ \
...
@@ -157,9 +157,9 @@ inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
...
@@ -157,9 +157,9 @@ inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
{ \
{ \
};
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
__fp16
)
MIGRAPHX_DETAIL_
DEVICE_
EXTEND_TRAIT_FOR
(
is_floating_point
,
__fp16
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
__fp16
)
MIGRAPHX_DETAIL_
DEVICE_
EXTEND_TRAIT_FOR
(
is_signed
,
__fp16
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
__fp16
)
MIGRAPHX_DETAIL_
DEVICE_
EXTEND_TRAIT_FOR
(
is_arithmetic
,
__fp16
)
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/target.cpp
View file @
c2b41d9d
...
@@ -98,6 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -98,6 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
ctx
.
set_exhaustive_tune_flag
(
options
.
exhaustive_tune
);
ctx
.
set_exhaustive_tune_flag
(
options
.
exhaustive_tune
);
std
::
set
<
shape
::
type_t
>
unsupported_types
(
shape
::
types
().
begin
(),
shape
::
types
().
end
());
std
::
set
<
shape
::
type_t
>
unsupported_types
(
shape
::
types
().
begin
(),
shape
::
types
().
end
());
unsupported_types
.
erase
(
shape
::
type_t
::
float_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
float_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
float8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
half_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
half_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
bool_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
bool_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int8_type
);
...
...
tools/api/migraphx.h
View file @
c2b41d9d
...
@@ -37,6 +37,7 @@
...
@@ -37,6 +37,7 @@
m(half_type, half) \
m(half_type, half) \
m(float_type, float) \
m(float_type, float) \
m(double_type, double) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
m(uint16_type, uint16_t) \
...
@@ -70,7 +71,9 @@ typedef enum
...
@@ -70,7 +71,9 @@ typedef enum
}
migraphx_shape_datatype_t
;
}
migraphx_shape_datatype_t
;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
<%
generate_c_header
()
%>
<%
generate_c_header
()
%>
#ifdef __cplusplus
#ifdef __cplusplus
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment