Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3a4d36cf
"src/targets/vscode:/vscode.git/clone" did not exist on "c46b0432184379c8fc6349f1febdbbe2b0a276ef"
Commit
3a4d36cf
authored
Sep 30, 2022
by
charlie
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test
parents
6bec381f
e19f78ae
Changes
384
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
170 additions
and
107 deletions
+170
-107
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+2
-2
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+16
-15
src/include/migraphx/concat_opt.hpp
src/include/migraphx/concat_opt.hpp
+2
-2
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+2
-2
src/include/migraphx/iterator.hpp
src/include/migraphx/iterator.hpp
+2
-2
src/include/migraphx/make_op.hpp
src/include/migraphx/make_op.hpp
+4
-0
src/include/migraphx/marker.hpp
src/include/migraphx/marker.hpp
+2
-2
src/include/migraphx/match/gelu_erf.hpp
src/include/migraphx/match/gelu_erf.hpp
+5
-5
src/include/migraphx/match/layernorm.hpp
src/include/migraphx/match/layernorm.hpp
+2
-2
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+6
-2
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+1
-1
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+6
-8
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+1
-1
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+1
-1
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+9
-1
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+3
-2
src/include/migraphx/op/fmod.hpp
src/include/migraphx/op/fmod.hpp
+0
-10
src/include/migraphx/op/gather.hpp
src/include/migraphx/op/gather.hpp
+1
-1
src/include/migraphx/op/mod.hpp
src/include/migraphx/op/mod.hpp
+1
-10
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+104
-38
No files found.
src/include/migraphx/allocation_model.hpp
View file @
3a4d36cf
...
...
@@ -205,7 +205,7 @@ struct allocation_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/check_shapes.hpp
View file @
3a4d36cf
...
...
@@ -101,7 +101,7 @@ struct check_shapes
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
return
*
this
;
}
...
...
@@ -164,7 +164,7 @@ struct check_shapes
*/
const
check_shapes
&
same_shape
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes do not match"
);
return
*
this
;
}
...
...
@@ -174,7 +174,7 @@ struct check_shapes
*/
const
check_shapes
&
same_type
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Types do not match"
);
return
*
this
;
}
...
...
@@ -184,10 +184,10 @@ struct check_shapes
*/
const
check_shapes
&
same_dims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Dimensions do not match"
);
if
(
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Min dynamic dimensions do not match"
);
return
*
this
;
}
...
...
@@ -197,7 +197,7 @@ struct check_shapes
*/
const
check_shapes
&
same_ndims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
}
...
...
@@ -207,7 +207,7 @@ struct check_shapes
*/
const
check_shapes
&
standard
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not in standard layout"
);
return
*
this
;
}
...
...
@@ -217,7 +217,7 @@ struct check_shapes
*/
const
check_shapes
&
standard_or_scalar
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
return
*
this
;
}
...
...
@@ -227,7 +227,7 @@ struct check_shapes
*/
const
check_shapes
&
packed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed"
);
return
*
this
;
}
...
...
@@ -237,7 +237,7 @@ struct check_shapes
*/
const
check_shapes
&
packed_or_broadcasted
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed nor broadcasted"
);
return
*
this
;
}
...
...
@@ -247,7 +247,7 @@ struct check_shapes
*/
const
check_shapes
&
tuple_type
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not tuple!"
);
return
*
this
;
}
...
...
@@ -257,7 +257,7 @@ struct check_shapes
*/
const
check_shapes
&
not_transposed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are transposed"
);
return
*
this
;
}
...
...
@@ -267,7 +267,7 @@ struct check_shapes
*/
const
check_shapes
&
not_broadcasted
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are broadcasted"
);
return
*
this
;
}
...
...
@@ -278,7 +278,7 @@ struct check_shapes
*/
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
return
*
this
;
}
...
...
@@ -288,7 +288,8 @@ struct check_shapes
*/
const
check_shapes
&
batch_not_transposed
()
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
if
(
not
this
->
all_of
(
[
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
MIGRAPHX_THROW
(
prefix
()
+
"Batch size is transposed"
);
return
*
this
;
}
...
...
src/include/migraphx/concat_opt.hpp
View file @
3a4d36cf
...
...
@@ -183,7 +183,7 @@ struct concat_optimization
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/context.hpp
View file @
3a4d36cf
...
...
@@ -246,7 +246,7 @@ struct context
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -306,7 +306,7 @@ struct context
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/iterator.hpp
View file @
3a4d36cf
...
...
@@ -31,9 +31,9 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
,
class
EndIterator
>
auto
is_end
(
rank
<
2
>
,
Iterator
it
,
EndIterator
)
->
decltype
(
!
it
.
_M_dereferenceable
())
auto
is_end
(
rank
<
2
>
,
Iterator
it
,
EndIterator
)
->
decltype
(
not
it
.
_M_dereferenceable
())
{
return
!
it
.
_M_dereferenceable
();
return
not
it
.
_M_dereferenceable
();
}
template
<
class
Iterator
,
class
EndIterator
>
...
...
src/include/migraphx/make_op.hpp
View file @
3a4d36cf
...
...
@@ -27,6 +27,8 @@
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -46,6 +48,8 @@ operation make_op(const std::string& name, const Value& v)
return
make_op_from_value
(
name
,
v
);
}
operation
make_json_op
(
const
std
::
string
&
name
,
const
std
::
string
&
s
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/marker.hpp
View file @
3a4d36cf
...
...
@@ -181,7 +181,7 @@ struct marker
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -233,7 +233,7 @@ struct marker
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/match/gelu_erf.hpp
View file @
3a4d36cf
...
...
@@ -38,11 +38,11 @@ struct gelu_erf_matcher
F
f
;
auto
erf_fn
()
const
{
return
f
(
"erf"
)(
used_once
(),
arg
(
0
)(
used_once
(),
f
(
"mul"
)(
either_arg
(
0
,
1
)
(
none_of
(
has_value
(
M_SQRT
1_
2
,
1e-3
)).
bind
(
"x"
),
has_value
(
M_SQRT1_2
,
1e-3
))
)));
auto
mul_1_sqrt_2
=
f
(
"mul"
)(
either_arg
(
0
,
1
)(
none_of
(
has_value
(
M_SQRT1_2
,
1e-3
)).
bind
(
"x"
),
has_value
(
M_SQRT1_2
,
1e-3
)));
auto
div_sqrt_2
=
f
(
"div"
)(
args
(
none_of
(
has_value
(
M_SQRT2
,
1e-3
)).
bind
(
"x"
),
has_value
(
M_SQRT2
,
1e-3
)));
return
f
(
"erf"
)(
used_once
(),
arg
(
0
)(
used_once
(),
any_of
(
mul_1_sqrt_2
,
div_sqrt_2
)));
}
auto
add_erf
()
const
...
...
src/include/migraphx/match/layernorm.hpp
View file @
3a4d36cf
...
...
@@ -50,8 +50,8 @@ struct layernorm_matcher
{
return
f
(
"div"
)(
arg
(
0
)(
x_minus_mean
()),
arg
(
1
)(
skip_broadcasts
(
f
(
"sqrt"
)(
arg
(
0
)(
f
(
"add"
)(
either_arg
(
0
,
1
)(
variance
(),
has_value
(
1e-12
f
))))))));
arg
(
1
)(
skip_broadcasts
(
f
(
"sqrt"
)(
arg
(
0
)(
f
(
"add"
)(
either_arg
(
0
,
1
)(
variance
(),
is_constant
().
bind
(
"eps"
))))))));
}
auto
matcher
()
const
{
return
layernorm_onnx
();
}
...
...
src/include/migraphx/matcher.hpp
View file @
3a4d36cf
...
...
@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
return
nullopt
;
}
MIGRAPHX_PRED_MATCHER
(
broadcast
,
instruction_ref
ins
)
{
return
contains
({
"broadcast"
,
"multibroadcast"
},
ins
->
name
());
}
template
<
class
...
Ms
>
auto
skip
(
Ms
...
ms
)
{
...
...
@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
template
<
class
...
Ms
>
auto
pointwise
(
Ms
...
ms
)
{
return
match
::
has_attribute
(
"pointwise"
)(
match
::
any_of
(
match
::
nargs
(
1
),
match
::
nargs
(
2
)),
ms
...);
return
match
::
has_attribute
(
"pointwise"
)(
ms
...);
}
}
// namespace match
...
...
src/include/migraphx/module.hpp
View file @
3a4d36cf
...
...
@@ -219,7 +219,7 @@ struct module
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
module
&
m
);
friend
bool
operator
==
(
const
module
&
x
,
const
module
&
y
);
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
not
(
x
==
y
);
}
private:
void
assign
(
const
module
&
m
);
...
...
src/include/migraphx/onnx.hpp
View file @
3a4d36cf
...
...
@@ -35,17 +35,13 @@ struct onnx_options
{
/// Old way to set default fixed dimension size
std
::
size_t
default_dim_value
=
0
;
/*!
* Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value
* set parser throws)
*/
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// parser throws)
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
,
0
};
/// Explicitly specify the dims of an input
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
/*!
* Explicitly specify dynamic dims of an input (if both map_input_dims and
* map_dyn_input_dims set parser throws)
*/
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
/// set parser throws)
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
=
{};
/// Continue parsing onnx file if an unknown operator is found
bool
skip_unknown_operators
=
false
;
...
...
@@ -53,6 +49,8 @@ struct onnx_options
bool
print_program_on_error
=
false
;
/// Max iter num for the loop operator
int64_t
max_loop_iterations
=
10
;
/// Use dynamic output for operators when available
bool
use_dyn_output
=
false
;
};
/// Create a program from an onnx file
...
...
src/include/migraphx/op/broadcast.hpp
View file @
3a4d36cf
...
...
@@ -70,7 +70,7 @@ struct broadcast
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than input ndims"
);
}
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
if
(
not
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
...
...
src/include/migraphx/op/concat.hpp
View file @
3a4d36cf
...
...
@@ -86,7 +86,7 @@ struct concat
{
if
(
l
!=
axis
)
{
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
}))
{
...
...
src/include/migraphx/op/convert.hpp
View file @
3a4d36cf
...
...
@@ -45,7 +45,15 @@ struct convert : unary<convert>
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
{
target_type
,
inputs
.
at
(
0
).
lens
(),
inputs
.
at
(
0
).
strides
()};
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
dynamic
())
{
return
{
target_type
,
input
.
dyn_dims
()};
}
else
{
return
{
target_type
,
input
.
lens
(),
input
.
strides
()};
}
}
std
::
string
point_op
()
const
...
...
src/include/migraphx/op/dot.hpp
View file @
3a4d36cf
...
...
@@ -43,13 +43,14 @@ struct dot
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dims operands"
);
}
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
...
...
src/include/migraphx/op/fmod.hpp
View file @
3a4d36cf
...
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -49,7 +40,6 @@ struct fmod : binary<fmod>
a
[
"commutative"
]
=
false
;
return
a
;
}
std
::
string
point_function
()
const
{
return
"fmod"
;
}
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
std
::
fmod
(
x
,
y
);
};
...
...
src/include/migraphx/op/gather.hpp
View file @
3a4d36cf
...
...
@@ -65,7 +65,7 @@ struct gather
auto
lens
=
inputs
[
0
].
lens
();
auto
type
=
inputs
[
0
].
type
();
lens
.
erase
(
lens
.
begin
()
+
axis
);
if
(
!
inputs
[
1
].
scalar
())
if
(
not
inputs
[
1
].
scalar
())
{
auto
ind_lens
=
inputs
[
1
].
lens
();
lens
.
insert
(
lens
.
begin
()
+
axis
,
ind_lens
.
begin
(),
ind_lens
.
end
());
...
...
src/include/migraphx/op/mod.hpp
View file @
3a4d36cf
...
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -47,9 +38,9 @@ struct mod : binary<mod>
{
auto
a
=
base_attributes
();
a
[
"commutative"
]
=
false
;
a
[
"point_op"
]
=
"${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})"
;
return
a
;
}
std
::
string
point_function
()
const
{
return
"mod"
;
}
auto
apply
()
const
{
return
[](
auto
x
,
auto
y
)
{
return
std
::
fmod
((
std
::
remainder
(
x
,
y
))
+
y
,
y
);
};
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
3a4d36cf
...
...
@@ -45,11 +45,13 @@ namespace op {
struct
nonmaxsuppression
{
bool
center_point_box
=
false
;
bool
use_dyn_output
=
false
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
center_point_box
,
"center_point_box"
));
return
pack
(
f
(
self
.
center_point_box
,
"center_point_box"
),
f
(
self
.
use_dyn_output
,
"use_dyn_output"
));
}
std
::
string
name
()
const
{
return
"nonmaxsuppression"
;
}
...
...
@@ -57,27 +59,81 @@ struct nonmaxsuppression
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
// requires at least 2 inputs
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
only_dims
(
3
);
auto
lens
=
inputs
.
front
().
lens
();
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
,
true
}.
only_dims
(
3
).
same_ndims
();
auto
boxes_max_lens
=
inputs
.
at
(
0
).
max_lens
();
// num batches * num boxes
const
auto
max_num_boxes
=
boxes_max_lens
.
at
(
0
)
*
boxes_max_lens
.
at
(
1
);
// check input shape
if
(
lens
[
1
]
!=
inputs
.
at
(
1
).
lens
()[
2
])
auto
fixed_shape_error_check
=
[
&
]()
{
auto
lens
=
inputs
.
front
().
lens
();
if
(
lens
[
1
]
!=
inputs
.
at
(
1
).
lens
()[
2
])
{
MIGRAPHX_THROW
(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input"
);
}
if
(
lens
[
0
]
!=
inputs
.
at
(
1
).
lens
()[
0
])
{
MIGRAPHX_THROW
(
"NonMaxSuppression: number of batches mismatch between boxes and scores input"
);
}
};
if
(
use_dyn_output
)
{
MIGRAPHX_THROW
(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input"
);
if
(
inputs
.
at
(
0
).
dynamic
())
{
// both boxes and scores should be dynamic
// check dynamic dimensions are consistent
const
auto
boxes_dims
=
inputs
.
at
(
0
).
dyn_dims
();
const
auto
scores_dims
=
inputs
.
at
(
1
).
dyn_dims
();
if
(
boxes_dims
.
at
(
1
)
!=
scores_dims
.
at
(
2
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic spatial dimension mismatch between "
"boxes and scores input"
);
}
if
(
boxes_dims
.
at
(
0
)
!=
scores_dims
.
at
(
0
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic number of batches mismatch between "
"boxes and scores input"
);
}
}
else
if
(
inputs
.
at
(
1
).
dynamic
())
{
// scores has dynamic shape, boxes fixed shape
// check that it is only a dynamic number of classes
const
auto
scores_dims
=
inputs
.
at
(
1
).
dyn_dims
();
const
auto
boxes_lens
=
inputs
.
at
(
0
).
lens
();
if
(
not
scores_dims
.
at
(
0
).
is_fixed
()
or
scores_dims
.
at
(
0
).
max
!=
boxes_lens
.
at
(
0
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: scores dynamic num_classes; num_batches not "
"fixed or mismatched"
);
}
if
(
not
scores_dims
.
at
(
2
).
is_fixed
()
or
scores_dims
.
at
(
2
).
max
!=
boxes_lens
.
at
(
1
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: scores dynamic num_classes; "
"spatial_dimension not fixed or mismatches"
);
}
}
else
{
fixed_shape_error_check
();
}
std
::
vector
<
shape
::
dynamic_dimension
>
out_lens
=
{};
out_lens
.
push_back
({
0
,
max_num_boxes
,
0
});
out_lens
.
push_back
({
3
,
3
,
0
});
return
{
shape
::
int64_type
,
out_lens
};
}
// check batch sizes
if
(
lens
[
0
]
!=
inputs
.
at
(
1
).
lens
()[
0
])
else
{
MIGRAPHX_THROW
(
"NonMaxSuppression: number of batches mismatch between boxes and scores input"
);
if
(
inputs
.
at
(
0
).
dynamic
()
or
inputs
.
at
(
1
).
dynamic
())
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic input shape with use_dyn_output set to false"
);
}
fixed_shape_error_check
();
std
::
vector
<
std
::
size_t
>
out_lens
=
{
max_num_boxes
,
3
};
return
{
shape
::
int64_type
,
out_lens
};
}
std
::
vector
<
int64_t
>
out_lens
(
2
);
out_lens
.
at
(
0
)
=
lens
.
at
(
1
);
out_lens
.
at
(
1
)
=
3
;
return
{
shape
::
int64_type
,
out_lens
};
}
struct
box
...
...
@@ -181,13 +237,13 @@ struct nonmaxsuppression
}
template
<
class
Output
,
class
Boxes
,
class
Scores
>
void
compute_nms
(
Output
output
,
Boxes
boxes
,
Scores
scores
,
const
shape
&
output_shape
,
std
::
size_t
max_output_boxes_per_class
,
double
iou_threshold
,
double
score_threshold
)
const
std
::
size_t
compute_nms
(
Output
output
,
Boxes
boxes
,
Scores
scores
,
const
shape
&
max_
output_shape
,
std
::
size_t
max_output_boxes_per_class
,
double
iou_threshold
,
double
score_threshold
)
const
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
const
auto
&
lens
=
scores
.
get_shape
().
lens
();
...
...
@@ -197,7 +253,7 @@ struct nonmaxsuppression
// boxes of a class with NMS applied [score, index]
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
int64_t
>
selected_indices
;
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
selected_boxes_inside_class
.
reserve
(
max_
output_shape
.
elements
());
// iterate over batches and classes
shape
comp_s
{
shape
::
double_type
,
{
num_batches
,
num_classes
}};
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
...
...
@@ -210,7 +266,7 @@ struct nonmaxsuppression
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
while
(
!
boxes_heap
.
empty
()
&&
while
(
not
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
// Check with existing selected boxes for this class, remove box if it
...
...
@@ -237,11 +293,14 @@ struct nonmaxsuppression
}
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
return
selected_indices
.
size
()
/
3
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// make buffer of maximum size
shape
max_output_shape
=
{
output_shape
.
type
(),
output_shape
.
max_lens
()};
argument
result
{
max_output_shape
};
std
::
size_t
max_output_boxes_per_class
=
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
...
...
@@ -249,22 +308,29 @@ struct nonmaxsuppression
{
return
result
;
}
double
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
double
>
())
:
0.0
f
;
double
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
double
>
())
:
0.0
f
;
double
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
double
>
())
:
0.0
f
;
double
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
double
>
())
:
0.0
f
;
std
::
size_t
num_selected
=
0
;
result
.
visit
([
&
](
auto
output
)
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
)
{
compute_nms
(
output
,
boxes
,
scores
,
output_shape
,
max_output_boxes_per_class
,
iou_threshold
,
score_threshold
);
num_selected
=
compute_nms
(
output
,
boxes
,
scores
,
max_
output_shape
,
max_output_boxes_per_class
,
iou_threshold
,
score_threshold
);
});
});
return
result
;
if
(
use_dyn_output
)
{
return
result
.
reshape
({
output_shape
.
type
(),
{
num_selected
,
3
}});
}
else
{
return
result
;
}
}
};
...
...
Prev
1
2
3
4
5
6
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment