Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5ec8f913
Commit
5ec8f913
authored
Sep 13, 2022
by
Ted Themistokleous
Committed by
Ted Themistokleous
Sep 13, 2022
Browse files
Merge branch 'develop' into simplify_1_mul_div_ops
parents
32d69e8e
d78bcdfb
Changes
183
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
156 additions
and
92 deletions
+156
-92
src/include/migraphx/iterator.hpp
src/include/migraphx/iterator.hpp
+2
-2
src/include/migraphx/make_op.hpp
src/include/migraphx/make_op.hpp
+4
-0
src/include/migraphx/marker.hpp
src/include/migraphx/marker.hpp
+2
-2
src/include/migraphx/match/gelu_erf.hpp
src/include/migraphx/match/gelu_erf.hpp
+5
-5
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+6
-2
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+1
-1
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+6
-8
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+1
-1
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+1
-1
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+9
-1
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+3
-2
src/include/migraphx/op/fmod.hpp
src/include/migraphx/op/fmod.hpp
+0
-9
src/include/migraphx/op/gather.hpp
src/include/migraphx/op/gather.hpp
+1
-1
src/include/migraphx/op/mod.hpp
src/include/migraphx/op/mod.hpp
+0
-9
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+104
-38
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+3
-2
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+2
-2
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+1
-1
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+3
-3
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+2
-2
No files found.
src/include/migraphx/iterator.hpp
View file @
5ec8f913
...
...
@@ -31,9 +31,9 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
,
class
EndIterator
>
auto
is_end
(
rank
<
2
>
,
Iterator
it
,
EndIterator
)
->
decltype
(
!
it
.
_M_dereferenceable
())
auto
is_end
(
rank
<
2
>
,
Iterator
it
,
EndIterator
)
->
decltype
(
not
it
.
_M_dereferenceable
())
{
return
!
it
.
_M_dereferenceable
();
return
not
it
.
_M_dereferenceable
();
}
template
<
class
Iterator
,
class
EndIterator
>
...
...
src/include/migraphx/make_op.hpp
View file @
5ec8f913
...
...
@@ -27,6 +27,8 @@
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -46,6 +48,8 @@ operation make_op(const std::string& name, const Value& v)
return
make_op_from_value
(
name
,
v
);
}
operation
make_json_op
(
const
std
::
string
&
name
,
const
std
::
string
&
s
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/marker.hpp
View file @
5ec8f913
...
...
@@ -181,7 +181,7 @@ struct marker
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -233,7 +233,7 @@ struct marker
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/match/gelu_erf.hpp
View file @
5ec8f913
...
...
@@ -38,11 +38,11 @@ struct gelu_erf_matcher
F
f
;
auto
erf_fn
()
const
{
return
f
(
"erf"
)(
used_once
(),
arg
(
0
)(
used_once
(),
f
(
"mul"
)(
either_arg
(
0
,
1
)
(
none_of
(
has_value
(
M_SQRT
1_
2
,
1e-3
)).
bind
(
"x"
),
has_value
(
M_SQRT1_2
,
1e-3
))
)));
auto
mul_1_sqrt_2
=
f
(
"mul"
)(
either_arg
(
0
,
1
)(
none_of
(
has_value
(
M_SQRT1_2
,
1e-3
)).
bind
(
"x"
),
has_value
(
M_SQRT1_2
,
1e-3
)));
auto
div_sqrt_2
=
f
(
"div"
)(
args
(
none_of
(
has_value
(
M_SQRT2
,
1e-3
)).
bind
(
"x"
),
has_value
(
M_SQRT2
,
1e-3
)));
return
f
(
"erf"
)(
used_once
(),
arg
(
0
)(
used_once
(),
any_of
(
mul_1_sqrt_2
,
div_sqrt_2
)));
}
auto
add_erf
()
const
...
...
src/include/migraphx/matcher.hpp
View file @
5ec8f913
...
...
@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
return
nullopt
;
}
MIGRAPHX_PRED_MATCHER
(
broadcast
,
instruction_ref
ins
)
{
return
contains
({
"broadcast"
,
"multibroadcast"
},
ins
->
name
());
}
template
<
class
...
Ms
>
auto
skip
(
Ms
...
ms
)
{
...
...
@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
template
<
class
...
Ms
>
auto
pointwise
(
Ms
...
ms
)
{
return
match
::
has_attribute
(
"pointwise"
)(
match
::
any_of
(
match
::
nargs
(
1
),
match
::
nargs
(
2
)),
ms
...);
return
match
::
has_attribute
(
"pointwise"
)(
ms
...);
}
}
// namespace match
...
...
src/include/migraphx/module.hpp
View file @
5ec8f913
...
...
@@ -219,7 +219,7 @@ struct module
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
module
&
m
);
friend
bool
operator
==
(
const
module
&
x
,
const
module
&
y
);
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
not
(
x
==
y
);
}
private:
void
assign
(
const
module
&
m
);
...
...
src/include/migraphx/onnx.hpp
View file @
5ec8f913
...
...
@@ -35,17 +35,13 @@ struct onnx_options
{
/// Old way to set default fixed dimension size
std
::
size_t
default_dim_value
=
0
;
/*!
* Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value
* set parser throws)
*/
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// parser throws)
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
,
0
};
/// Explicitly specify the dims of an input
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
/*!
* Explicitly specify dynamic dims of an input (if both map_input_dims and
* map_dyn_input_dims set parser throws)
*/
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
/// set parser throws)
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
=
{};
/// Continue parsing onnx file if an unknown operator is found
bool
skip_unknown_operators
=
false
;
...
...
@@ -53,6 +49,8 @@ struct onnx_options
bool
print_program_on_error
=
false
;
/// Max iter num for the loop operator
int64_t
max_loop_iterations
=
10
;
/// Use dynamic output for operators when available
bool
use_dyn_output
=
false
;
};
/// Create a program from an onnx file
...
...
src/include/migraphx/op/broadcast.hpp
View file @
5ec8f913
...
...
@@ -70,7 +70,7 @@ struct broadcast
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than input ndims"
);
}
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
if
(
not
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
...
...
src/include/migraphx/op/concat.hpp
View file @
5ec8f913
...
...
@@ -86,7 +86,7 @@ struct concat
{
if
(
l
!=
axis
)
{
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
s
)
{
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
}))
{
...
...
src/include/migraphx/op/convert.hpp
View file @
5ec8f913
...
...
@@ -45,7 +45,15 @@ struct convert : unary<convert>
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
{
target_type
,
inputs
.
at
(
0
).
lens
(),
inputs
.
at
(
0
).
strides
()};
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
dynamic
())
{
return
{
target_type
,
input
.
dyn_dims
()};
}
else
{
return
{
target_type
,
input
.
lens
(),
input
.
strides
()};
}
}
std
::
string
point_op
()
const
...
...
src/include/migraphx/op/dot.hpp
View file @
5ec8f913
...
...
@@ -43,13 +43,14 @@ struct dot
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dims operands"
);
}
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
...
...
src/include/migraphx/op/fmod.hpp
View file @
5ec8f913
...
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/gather.hpp
View file @
5ec8f913
...
...
@@ -65,7 +65,7 @@ struct gather
auto
lens
=
inputs
[
0
].
lens
();
auto
type
=
inputs
[
0
].
type
();
lens
.
erase
(
lens
.
begin
()
+
axis
);
if
(
!
inputs
[
1
].
scalar
())
if
(
not
inputs
[
1
].
scalar
())
{
auto
ind_lens
=
inputs
[
1
].
lens
();
lens
.
insert
(
lens
.
begin
()
+
axis
,
ind_lens
.
begin
(),
ind_lens
.
end
());
...
...
src/include/migraphx/op/mod.hpp
View file @
5ec8f913
...
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
5ec8f913
...
...
@@ -45,11 +45,13 @@ namespace op {
struct
nonmaxsuppression
{
bool
center_point_box
=
false
;
bool
use_dyn_output
=
false
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
center_point_box
,
"center_point_box"
));
return
pack
(
f
(
self
.
center_point_box
,
"center_point_box"
),
f
(
self
.
use_dyn_output
,
"use_dyn_output"
));
}
std
::
string
name
()
const
{
return
"nonmaxsuppression"
;
}
...
...
@@ -57,27 +59,81 @@ struct nonmaxsuppression
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
// requires at least 2 inputs
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
only_dims
(
3
);
auto
lens
=
inputs
.
front
().
lens
();
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
,
true
}.
only_dims
(
3
).
same_ndims
();
auto
boxes_max_lens
=
inputs
.
at
(
0
).
max_lens
();
// num batches * num boxes
const
auto
max_num_boxes
=
boxes_max_lens
.
at
(
0
)
*
boxes_max_lens
.
at
(
1
);
// check input shape
if
(
lens
[
1
]
!=
inputs
.
at
(
1
).
lens
()[
2
])
auto
fixed_shape_error_check
=
[
&
]()
{
auto
lens
=
inputs
.
front
().
lens
();
if
(
lens
[
1
]
!=
inputs
.
at
(
1
).
lens
()[
2
])
{
MIGRAPHX_THROW
(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input"
);
}
if
(
lens
[
0
]
!=
inputs
.
at
(
1
).
lens
()[
0
])
{
MIGRAPHX_THROW
(
"NonMaxSuppression: number of batches mismatch between boxes and scores input"
);
}
};
if
(
use_dyn_output
)
{
MIGRAPHX_THROW
(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input"
);
if
(
inputs
.
at
(
0
).
dynamic
())
{
// both boxes and scores should be dynamic
// check dynamic dimensions are consistent
const
auto
boxes_dims
=
inputs
.
at
(
0
).
dyn_dims
();
const
auto
scores_dims
=
inputs
.
at
(
1
).
dyn_dims
();
if
(
boxes_dims
.
at
(
1
)
!=
scores_dims
.
at
(
2
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic spatial dimension mismatch between "
"boxes and scores input"
);
}
if
(
boxes_dims
.
at
(
0
)
!=
scores_dims
.
at
(
0
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic number of batches mismatch between "
"boxes and scores input"
);
}
}
else
if
(
inputs
.
at
(
1
).
dynamic
())
{
// scores has dynamic shape, boxes fixed shape
// check that it is only a dynamic number of classes
const
auto
scores_dims
=
inputs
.
at
(
1
).
dyn_dims
();
const
auto
boxes_lens
=
inputs
.
at
(
0
).
lens
();
if
(
not
scores_dims
.
at
(
0
).
is_fixed
()
or
scores_dims
.
at
(
0
).
max
!=
boxes_lens
.
at
(
0
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: scores dynamic num_classes; num_batches not "
"fixed or mismatched"
);
}
if
(
not
scores_dims
.
at
(
2
).
is_fixed
()
or
scores_dims
.
at
(
2
).
max
!=
boxes_lens
.
at
(
1
))
{
MIGRAPHX_THROW
(
"NonMaxSuppression: scores dynamic num_classes; "
"spatial_dimension not fixed or mismatches"
);
}
}
else
{
fixed_shape_error_check
();
}
std
::
vector
<
shape
::
dynamic_dimension
>
out_lens
=
{};
out_lens
.
push_back
({
0
,
max_num_boxes
,
0
});
out_lens
.
push_back
({
3
,
3
,
0
});
return
{
shape
::
int64_type
,
out_lens
};
}
// check batch sizes
if
(
lens
[
0
]
!=
inputs
.
at
(
1
).
lens
()[
0
])
else
{
MIGRAPHX_THROW
(
"NonMaxSuppression: number of batches mismatch between boxes and scores input"
);
if
(
inputs
.
at
(
0
).
dynamic
()
or
inputs
.
at
(
1
).
dynamic
())
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dynamic input shape with use_dyn_output set to false"
);
}
fixed_shape_error_check
();
std
::
vector
<
std
::
size_t
>
out_lens
=
{
max_num_boxes
,
3
};
return
{
shape
::
int64_type
,
out_lens
};
}
std
::
vector
<
int64_t
>
out_lens
(
2
);
out_lens
.
at
(
0
)
=
lens
.
at
(
1
);
out_lens
.
at
(
1
)
=
3
;
return
{
shape
::
int64_type
,
out_lens
};
}
struct
box
...
...
@@ -181,13 +237,13 @@ struct nonmaxsuppression
}
template
<
class
Output
,
class
Boxes
,
class
Scores
>
void
compute_nms
(
Output
output
,
Boxes
boxes
,
Scores
scores
,
const
shape
&
output_shape
,
std
::
size_t
max_output_boxes_per_class
,
double
iou_threshold
,
double
score_threshold
)
const
std
::
size_t
compute_nms
(
Output
output
,
Boxes
boxes
,
Scores
scores
,
const
shape
&
max_
output_shape
,
std
::
size_t
max_output_boxes_per_class
,
double
iou_threshold
,
double
score_threshold
)
const
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
const
auto
&
lens
=
scores
.
get_shape
().
lens
();
...
...
@@ -197,7 +253,7 @@ struct nonmaxsuppression
// boxes of a class with NMS applied [score, index]
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
int64_t
>
selected_indices
;
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
selected_boxes_inside_class
.
reserve
(
max_
output_shape
.
elements
());
// iterate over batches and classes
shape
comp_s
{
shape
::
double_type
,
{
num_batches
,
num_classes
}};
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
...
...
@@ -210,7 +266,7 @@ struct nonmaxsuppression
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
while
(
!
boxes_heap
.
empty
()
&&
while
(
not
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
// Check with existing selected boxes for this class, remove box if it
...
...
@@ -237,11 +293,14 @@ struct nonmaxsuppression
}
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
return
selected_indices
.
size
()
/
3
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// make buffer of maximum size
shape
max_output_shape
=
{
output_shape
.
type
(),
output_shape
.
max_lens
()};
argument
result
{
max_output_shape
};
std
::
size_t
max_output_boxes_per_class
=
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
...
...
@@ -249,22 +308,29 @@ struct nonmaxsuppression
{
return
result
;
}
double
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
double
>
())
:
0.0
f
;
double
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
double
>
())
:
0.0
f
;
double
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
double
>
())
:
0.0
f
;
double
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
double
>
())
:
0.0
f
;
std
::
size_t
num_selected
=
0
;
result
.
visit
([
&
](
auto
output
)
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
)
{
compute_nms
(
output
,
boxes
,
scores
,
output_shape
,
max_output_boxes_per_class
,
iou_threshold
,
score_threshold
);
num_selected
=
compute_nms
(
output
,
boxes
,
scores
,
max_
output_shape
,
max_output_boxes_per_class
,
iou_threshold
,
score_threshold
);
});
});
return
result
;
if
(
use_dyn_output
)
{
return
result
.
reshape
({
output_shape
.
type
(),
{
num_selected
,
3
}});
}
else
{
return
result
;
}
}
};
...
...
src/include/migraphx/op/quant_dot.hpp
View file @
5ec8f913
...
...
@@ -49,13 +49,14 @@ struct quant_dot
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
}
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"QUANT_DOT: dot only accept 2 or more dims operands"
);
}
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"QUANT_DOT: batch size of A and B mismatch: {"
+
...
...
src/include/migraphx/op/slice.hpp
View file @
5ec8f913
...
...
@@ -78,7 +78,7 @@ struct slice
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
auto
offset
=
0
;
if
(
!
axes
.
empty
())
if
(
not
axes
.
empty
())
{
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
...
...
@@ -109,7 +109,7 @@ struct slice
MIGRAPHX_THROW
(
"SLICE: input axis "
+
to_string_range
(
axes
)
+
" out of range"
);
}
if
(
starts
.
size
()
!=
axes
.
size
()
||
axes
.
size
()
!=
ends
.
size
())
if
(
starts
.
size
()
!=
axes
.
size
()
or
axes
.
size
()
!=
ends
.
size
())
{
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
}
...
...
src/include/migraphx/op/transpose.hpp
View file @
5ec8f913
...
...
@@ -59,7 +59,7 @@ struct transpose
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
if
(
not
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
}
...
...
src/include/migraphx/operation.hpp
View file @
5ec8f913
...
...
@@ -1066,7 +1066,7 @@ struct operation
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -1237,7 +1237,7 @@ struct operation
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
@@ -1276,7 +1276,7 @@ inline const ValueType& any_cast(const operation& x)
}
#endif
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
not
(
x
==
y
);
}
inline
value
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
...
...
src/include/migraphx/pass.hpp
View file @
5ec8f913
...
...
@@ -238,7 +238,7 @@ struct pass
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -292,7 +292,7 @@ struct pass
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
Prev
1
2
3
4
5
6
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment