Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9dc0b779
Unverified
Commit
9dc0b779
authored
Sep 12, 2022
by
Paul Fultz II
Committed by
GitHub
Sep 12, 2022
Browse files
Merge branch 'develop' into jit-concat
parents
30cde6df
d78bcdfb
Changes
146
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
78 additions
and
46 deletions
+78
-46
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+1
-1
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+1
-1
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+3
-2
src/include/migraphx/op/fmod.hpp
src/include/migraphx/op/fmod.hpp
+0
-9
src/include/migraphx/op/gather.hpp
src/include/migraphx/op/gather.hpp
+1
-1
src/include/migraphx/op/mod.hpp
src/include/migraphx/op/mod.hpp
+0
-9
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+1
-1
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+3
-2
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+2
-2
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+1
-1
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+3
-3
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+2
-2
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+1
-1
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+4
-4
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+1
-1
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+1
-1
src/include/migraphx/rewrite_gelu.hpp
src/include/migraphx/rewrite_gelu.hpp
+48
-0
src/include/migraphx/schedule_model.hpp
src/include/migraphx/schedule_model.hpp
+2
-2
src/include/migraphx/stream_model.hpp
src/include/migraphx/stream_model.hpp
+2
-2
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+1
-1
No files found.
src/include/migraphx/op/broadcast.hpp
View file @
9dc0b779
...
...
@@ -70,7 +70,7 @@ struct broadcast
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than input ndims"
);
}
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
if
(
not
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
...
...
src/include/migraphx/op/concat.hpp
View file @
9dc0b779
...
...
@@ -86,7 +86,7 @@ struct concat
{
if
(
l
!=
axis
)
{
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
}))
{
...
...
src/include/migraphx/op/dot.hpp
View file @
9dc0b779
...
...
@@ -43,13 +43,14 @@ struct dot
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dims operands"
);
}
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
...
...
src/include/migraphx/op/fmod.hpp
View file @
9dc0b779
...
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/gather.hpp
View file @
9dc0b779
...
...
@@ -65,7 +65,7 @@ struct gather
auto
lens
=
inputs
[
0
].
lens
();
auto
type
=
inputs
[
0
].
type
();
lens
.
erase
(
lens
.
begin
()
+
axis
);
if
(
!
inputs
[
1
].
scalar
())
if
(
not
inputs
[
1
].
scalar
())
{
auto
ind_lens
=
inputs
[
1
].
lens
();
lens
.
insert
(
lens
.
begin
()
+
axis
,
ind_lens
.
begin
(),
ind_lens
.
end
());
...
...
src/include/migraphx/op/mod.hpp
View file @
9dc0b779
...
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
9dc0b779
...
...
@@ -266,7 +266,7 @@ struct nonmaxsuppression
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
while
(
!
boxes_heap
.
empty
()
&&
while
(
not
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
// Check with existing selected boxes for this class, remove box if it
...
...
src/include/migraphx/op/quant_dot.hpp
View file @
9dc0b779
...
...
@@ -49,13 +49,14 @@ struct quant_dot
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
}
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"QUANT_DOT: dot only accept 2 or more dims operands"
);
}
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"QUANT_DOT: batch size of A and B mismatch: {"
+
...
...
src/include/migraphx/op/slice.hpp
View file @
9dc0b779
...
...
@@ -78,7 +78,7 @@ struct slice
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
auto
offset
=
0
;
if
(
!
axes
.
empty
())
if
(
not
axes
.
empty
())
{
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
...
...
@@ -109,7 +109,7 @@ struct slice
MIGRAPHX_THROW
(
"SLICE: input axis "
+
to_string_range
(
axes
)
+
" out of range"
);
}
if
(
starts
.
size
()
!=
axes
.
size
()
||
axes
.
size
()
!=
ends
.
size
())
if
(
starts
.
size
()
!=
axes
.
size
()
or
axes
.
size
()
!=
ends
.
size
())
{
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
}
...
...
src/include/migraphx/op/transpose.hpp
View file @
9dc0b779
...
...
@@ -59,7 +59,7 @@ struct transpose
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
if
(
not
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
}
...
...
src/include/migraphx/operation.hpp
View file @
9dc0b779
...
...
@@ -1066,7 +1066,7 @@ struct operation
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -1237,7 +1237,7 @@ struct operation
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
@@ -1276,7 +1276,7 @@ inline const ValueType& any_cast(const operation& x)
}
#endif
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
not
(
x
==
y
);
}
inline
value
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
...
...
src/include/migraphx/pass.hpp
View file @
9dc0b779
...
...
@@ -238,7 +238,7 @@ struct pass
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -292,7 +292,7 @@ struct pass
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/program.hpp
View file @
9dc0b779
...
...
@@ -124,7 +124,7 @@ struct program
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
not
(
x
==
y
);
}
// module related api
module
*
create_module
(
const
std
::
string
&
name
);
...
...
src/include/migraphx/raw_data.hpp
View file @
9dc0b779
...
...
@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
template
<
class
T
>
bool
matches
()
const
{
return
is_data_ptr
<
T
>
{}
||
return
is_data_ptr
<
T
>
{}
or
self
->
get_shape
().
type
()
==
migraphx
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
}
...
...
@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
{
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
...
vs
)
{
detail
::
visit_all_pack
(
s
,
vs
...)(
x
,
xs
...);
};
}
...
...
@@ -241,7 +241,7 @@ template <class T>
auto
visit_all
(
const
std
::
vector
<
T
>&
x
)
{
auto
&&
s
=
x
.
front
().
get_shape
();
if
(
!
std
::
all_of
(
if
(
not
std
::
all_of
(
x
.
begin
(),
x
.
end
(),
[
&
](
const
T
&
y
)
{
return
y
.
get_shape
().
type
()
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
v
)
{
...
...
@@ -281,7 +281,7 @@ template <class T,
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/reflect.hpp
View file @
9dc0b779
...
...
@@ -129,7 +129,7 @@ template <class T>
struct
reflect_equality
{
friend
bool
operator
==
(
const
T
&
x
,
const
T
&
y
)
{
return
reflect_tie
(
x
)
==
reflect_tie
(
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
not
(
x
==
y
);
}
};
template
<
class
T
>
...
...
src/include/migraphx/requires.hpp
View file @
9dc0b779
...
...
@@ -31,7 +31,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
bool
...
Bs
>
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
||
true
)...
>>
// NOLINT
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
or
true
)...
>>
// NOLINT
{
};
...
...
src/include/migraphx/rewrite_gelu.hpp
0 → 100644
View file @
9dc0b779
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
/**
* Rewrite gelu standard formula as the sigmoid approximation formula
*/
struct
rewrite_gelu
{
std
::
string
name
()
const
{
return
"rewrite_gelu"
;
}
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/schedule_model.hpp
View file @
9dc0b779
...
...
@@ -208,7 +208,7 @@ struct schedule_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -274,7 +274,7 @@ struct schedule_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/stream_model.hpp
View file @
9dc0b779
...
...
@@ -216,7 +216,7 @@ struct stream_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -288,7 +288,7 @@ struct stream_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/streamutils.hpp
View file @
9dc0b779
...
...
@@ -41,7 +41,7 @@ struct stream_range_container
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
stream_range_container
&
sr
)
{
assert
(
sr
.
r
!=
nullptr
);
if
(
!
sr
.
r
->
empty
())
if
(
not
sr
.
r
->
empty
())
{
os
<<
sr
.
r
->
front
();
std
::
for_each
(
...
...
Prev
1
2
3
4
5
6
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment