Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
63ed48e4
Commit
63ed48e4
authored
Nov 09, 2023
by
Umang Yadav
Browse files
Fixes
parent
711ff872
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
12 additions
and
23 deletions
+12
-23
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+5
-15
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+0
-3
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
...argets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
+2
-1
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+3
-2
test/py/test_shape.py
test/py/test_shape.py
+1
-1
test/verify/test_literal_limits.cpp
test/verify/test_literal_limits.cpp
+1
-1
No files found.
src/include/migraphx/migraphx_float8.hpp
View file @
63ed48e4
...
@@ -474,10 +474,12 @@ template <>
...
@@ -474,10 +474,12 @@ template <>
class
NumericLimits
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
class
NumericLimits
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
{
{
public:
public:
// TODO :figure out epsilon in Hex to make it constexpr
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
epsilon
()
epsilon
()
{
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
float
(
0.0625
));
return
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
(
0x28
,
migraphx_fp8
::
hip_f8
<>::
from_bits
());
}
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
...
@@ -493,13 +495,6 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
...
@@ -493,13 +495,6 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
return
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
return
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
min
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
-
1.0
f
)
*
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
lowest
()
lowest
()
{
{
...
@@ -521,7 +516,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
...
@@ -521,7 +516,8 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
epsilon
()
epsilon
()
{
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
float
(
0.125
));
return
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
(
0x34
,
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>::
from_bits
());
}
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
...
@@ -538,12 +534,6 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
...
@@ -538,12 +534,6 @@ class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
());
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
());
}
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
min
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
float
(
-
1.0
f
))
*
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
();
}
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
lowest
()
lowest
()
...
...
src/include/migraphx/requires.hpp
View file @
63ed48e4
...
@@ -38,9 +38,6 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs or true)...>> // NOLINT
...
@@ -38,9 +38,6 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs or true)...>> // NOLINT
template
<
bool
B
>
template
<
bool
B
>
using
bool_c
=
std
::
integral_constant
<
bool
,
B
>
;
using
bool_c
=
std
::
integral_constant
<
bool
,
B
>
;
template
<
class
From
,
class
To
>
using
is_convertible
=
std
::
is_convertible
<
From
,
To
>
;
#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
View file @
63ed48e4
...
@@ -36,7 +36,8 @@ namespace migraphx {
...
@@ -36,7 +36,8 @@ namespace migraphx {
template
<
class
F
,
class
T
,
class
...
Ts
>
template
<
class
F
,
class
T
,
class
...
Ts
>
__device__
void
pointwise_tensor
(
index
idx
,
F
f
,
T
out
,
Ts
...
xs
)
__device__
void
pointwise_tensor
(
index
idx
,
F
f
,
T
out
,
Ts
...
xs
)
{
{
idx
.
global_stride
(
out
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
out
[
i
]
=
f
(
xs
[
i
]...);
});
idx
.
global_stride
(
out
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
out
[
i
]
=
implicit_conversion
(
f
(
xs
[
i
]...));
});
}
}
template
<
class
...
Transforms
>
template
<
class
...
Transforms
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
63ed48e4
...
@@ -244,8 +244,9 @@ struct reducer_base
...
@@ -244,8 +244,9 @@ struct reducer_base
{
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
t
=
derived
.
slice
(
x
);
auto
t
=
derived
.
slice
(
x
);
return
make_storage_access
<
typename
decltype
(
t
)
::
type
>
(
return
make_storage_access
<
typename
decltype
(
t
)
::
type
>
([
=
](
auto
i
,
auto
...)
->
auto
&
{
[
=
](
auto
i
,
auto
...)
->
auto
&
{
return
t
[
i
];
});
return
t
[
i
];
});
}
}
}
}
...
...
test/py/test_shape.py
View file @
63ed48e4
...
@@ -81,7 +81,7 @@ def test_create_dyn_shape():
...
@@ -81,7 +81,7 @@ def test_create_dyn_shape():
def
test_type_enum
():
def
test_type_enum
():
mgx_types
=
[
mgx_types
=
[
'bool_type'
,
'double_type'
,
'float_type'
,
'half_type'
,
'float_type'
,
'int16_type'
,
'bool_type'
,
'double_type'
,
'float_type'
,
'half_type'
,
'float
8
_type'
,
'int16_type'
,
'int32_type'
,
'int64_type'
,
'int8_type'
,
'uint16_type'
,
'uint32_type'
,
'int32_type'
,
'int64_type'
,
'int8_type'
,
'uint16_type'
,
'uint32_type'
,
'uint64_type'
,
'uint8_type'
'uint64_type'
,
'uint8_type'
]
]
...
...
test/verify/test_literal_limits.cpp
View file @
63ed48e4
...
@@ -22,10 +22,10 @@
...
@@ -22,10 +22,10 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include "migraphx/migraphx_float8.hpp"
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <limits>
#include <limits>
template
<
migraphx
::
shape
::
type_t
Q
,
typename
T
>
template
<
migraphx
::
shape
::
type_t
Q
,
typename
T
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment