Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
72011beb
Commit
72011beb
authored
Sep 06, 2022
by
Paul
Browse files
Merge branch 'develop' into jit-concat-pointwise
parents
d48d9bf7
d37a4df9
Changes
118
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
52 additions
and
50 deletions
+52
-50
src/include/migraphx/tune_axis.hpp
src/include/migraphx/tune_axis.hpp
+1
-1
src/instruction.cpp
src/instruction.cpp
+6
-6
src/module.cpp
src/module.cpp
+6
-6
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+6
-5
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+5
-5
src/onnx/padding.cpp
src/onnx/padding.cpp
+2
-2
src/onnx/parse_cast.cpp
src/onnx/parse_cast.cpp
+1
-1
src/onnx/parse_constant_fill.cpp
src/onnx/parse_constant_fill.cpp
+1
-1
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+1
-1
src/onnx/parse_lpnormalization.cpp
src/onnx/parse_lpnormalization.cpp
+1
-1
src/onnx/parse_matmul.cpp
src/onnx/parse_matmul.cpp
+2
-1
src/onnx/parse_mod.cpp
src/onnx/parse_mod.cpp
+2
-2
src/onnx/parse_nonzero.cpp
src/onnx/parse_nonzero.cpp
+1
-1
src/onnx/parse_pad.cpp
src/onnx/parse_pad.cpp
+1
-1
src/onnx/parse_pooling.cpp
src/onnx/parse_pooling.cpp
+2
-2
src/onnx/parse_pow.cpp
src/onnx/parse_pow.cpp
+1
-1
src/onnx/parse_resize.cpp
src/onnx/parse_resize.cpp
+2
-2
src/onnx/parse_reversesequence.cpp
src/onnx/parse_reversesequence.cpp
+1
-1
src/opt/memory_coloring.cpp
src/opt/memory_coloring.cpp
+1
-1
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+9
-9
No files found.
src/include/migraphx/tune_axis.hpp
View file @
72011beb
...
@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {
inline
int
tune_axis
(
const
int
n_dim
,
const
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
inline
int
tune_axis
(
const
int
n_dim
,
const
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
{
{
if
(
axis
>=
n_dim
||
std
::
abs
(
axis
)
>
n_dim
)
if
(
axis
>=
n_dim
or
std
::
abs
(
axis
)
>
n_dim
)
{
{
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
}
}
...
...
src/instruction.cpp
View file @
72011beb
...
@@ -176,13 +176,13 @@ bool operator==(const instruction& x, const instruction& y)
...
@@ -176,13 +176,13 @@ bool operator==(const instruction& x, const instruction& y)
return
true
;
return
true
;
}
}
bool
operator
!=
(
const
instruction
&
x
,
const
instruction
&
y
)
{
return
!
(
x
==
y
);
}
bool
operator
!=
(
const
instruction
&
x
,
const
instruction
&
y
)
{
return
not
(
x
==
y
);
}
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
!
(
i
==
ref
);
}
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
not
(
i
==
ref
);
}
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
not
(
i
==
ref
);
}
void
instruction
::
add_output
(
instruction_ref
ins
)
void
instruction
::
add_output
(
instruction_ref
ins
)
{
{
...
@@ -361,7 +361,7 @@ void instruction::print(std::ostream& os,
...
@@ -361,7 +361,7 @@ void instruction::print(std::ostream& os,
os
<<
"{"
<<
ins
->
get_literal
()
<<
"}"
;
os
<<
"{"
<<
ins
->
get_literal
()
<<
"}"
;
}
}
if
(
!
ins
->
inputs
().
empty
())
if
(
not
ins
->
inputs
().
empty
())
{
{
char
delim
=
'('
;
char
delim
=
'('
;
for
(
auto
&&
arg
:
ins
->
inputs
())
for
(
auto
&&
arg
:
ins
->
inputs
())
...
@@ -374,7 +374,7 @@ void instruction::print(std::ostream& os,
...
@@ -374,7 +374,7 @@ void instruction::print(std::ostream& os,
}
}
// print module inputs
// print module inputs
if
(
!
ins
->
module_inputs
().
empty
())
if
(
not
ins
->
module_inputs
().
empty
())
{
{
std
::
string
delim
=
", ["
;
std
::
string
delim
=
", ["
;
for
(
auto
&&
mod_arg
:
ins
->
module_inputs
())
for
(
auto
&&
mod_arg
:
ins
->
module_inputs
())
...
@@ -446,7 +446,7 @@ operation instruction::normalized_operator() const
...
@@ -446,7 +446,7 @@ operation instruction::normalized_operator() const
if
(
this
->
need_normalization
())
if
(
this
->
need_normalization
())
{
{
auto
s
=
this
->
inputs
().
front
()
->
get_shape
();
auto
s
=
this
->
inputs
().
front
()
->
get_shape
();
if
(
!
normalize_attributes
(
o
,
s
.
max_lens
()))
if
(
not
normalize_attributes
(
o
,
s
.
max_lens
()))
return
this
->
get_operator
();
return
this
->
get_operator
();
}
}
return
o
;
return
o
;
...
...
src/module.cpp
View file @
72011beb
...
@@ -141,12 +141,12 @@ void module::set_bypass(bool b) { impl->bypass = b; }
...
@@ -141,12 +141,12 @@ void module::set_bypass(bool b) { impl->bypass = b; }
void
module
::
assign
(
const
module
&
m
)
void
module
::
assign
(
const
module
&
m
)
{
{
// copy the impl
// copy the impl
if
(
!
impl
)
if
(
not
impl
)
impl
=
std
::
make_unique
<
module_impl
>
();
impl
=
std
::
make_unique
<
module_impl
>
();
*
impl
=
*
m
.
impl
;
*
impl
=
*
m
.
impl
;
// clear instructions
// clear instructions
if
(
!
impl
->
instructions
.
empty
())
if
(
not
impl
->
instructions
.
empty
())
{
{
impl
->
clear
();
impl
->
clear
();
}
}
...
@@ -346,7 +346,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
...
@@ -346,7 +346,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
assert
(
out
->
valid
(
begin
()));
assert
(
out
->
valid
(
begin
()));
}
}
// Replacement should not be dead code unless its the last instruction
// Replacement should not be dead code unless its the last instruction
assert
(
!
rep
->
outputs
().
empty
()
or
rep
==
std
::
prev
(
end
()));
assert
(
not
rep
->
outputs
().
empty
()
or
rep
==
std
::
prev
(
end
()));
// Output of the original instruction should only be the replacement or empty
// Output of the original instruction should only be the replacement or empty
assert
(
ins
->
outputs
().
empty
()
or
std
::
all_of
(
ins
->
outputs
().
begin
(),
assert
(
ins
->
outputs
().
empty
()
or
std
::
all_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
ins
->
outputs
().
end
(),
...
@@ -598,7 +598,7 @@ instruction_ref module::validate() const
...
@@ -598,7 +598,7 @@ instruction_ref module::validate() const
auto
inputs
=
i
.
inputs
();
auto
inputs
=
i
.
inputs
();
bool
check_order
=
std
::
all_of
(
bool
check_order
=
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
in
)
{
return
has_instruction
(
in
);
});
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
in
)
{
return
has_instruction
(
in
);
});
return
!
i
.
valid
(
impl
->
instructions
.
begin
(),
check_order
);
return
not
i
.
valid
(
impl
->
instructions
.
begin
(),
check_order
);
});
});
}
}
...
@@ -754,7 +754,7 @@ void module::print_graph(std::ostream& os, bool brief) const
...
@@ -754,7 +754,7 @@ void module::print_graph(std::ostream& os, bool brief) const
label
=
to_string
(
ins
->
get_operator
());
label
=
to_string
(
ins
->
get_operator
());
os
<<
"
\t
"
<<
enclose_name
(
ins_names
.
at
(
ins
))
<<
"[label="
<<
enclose_name
(
label
)
<<
"]"
;
os
<<
"
\t
"
<<
enclose_name
(
ins_names
.
at
(
ins
))
<<
"[label="
<<
enclose_name
(
label
)
<<
"]"
;
os
<<
";"
<<
std
::
endl
;
os
<<
";"
<<
std
::
endl
;
if
(
!
ins
->
inputs
().
empty
())
if
(
not
ins
->
inputs
().
empty
())
{
{
for
(
auto
&&
arg
:
ins
->
inputs
())
for
(
auto
&&
arg
:
ins
->
inputs
())
{
{
...
@@ -908,7 +908,7 @@ module& module::sort()
...
@@ -908,7 +908,7 @@ module& module::sort()
this
->
move_instruction
(
ins
,
this
->
begin
());
this
->
move_instruction
(
ins
,
this
->
begin
());
for
(
auto
child
:
ins
->
inputs
())
for
(
auto
child
:
ins
->
inputs
())
{
{
if
(
!
contains
(
this
->
impl
->
instructions
,
child
))
if
(
not
contains
(
this
->
impl
->
instructions
,
child
))
{
{
continue
;
continue
;
}
}
...
...
src/normalize_attributes.cpp
View file @
72011beb
...
@@ -79,14 +79,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
...
@@ -79,14 +79,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
{
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
include_max
))
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
include_max
))
{
{
if
(
!
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less_equal
<>
{}))
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less_equal
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR: value out of range!"
);
MIGRAPHX_THROW
(
"TUNE_VECTOR: value out of range!"
);
}
}
}
}
else
else
{
{
if
(
!
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less
<>
{}))
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR: value out of range!"
);
MIGRAPHX_THROW
(
"TUNE_VECTOR: value out of range!"
);
}
}
...
@@ -118,14 +118,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
...
@@ -118,14 +118,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
{
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
include_min
))
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
include_min
))
{
{
if
(
!
std
::
equal
(
min_vals
.
begin
(),
min_vals
.
end
(),
result
.
begin
(),
std
::
less_equal
<>
{}))
if
(
not
std
::
equal
(
min_vals
.
begin
(),
min_vals
.
end
(),
result
.
begin
(),
std
::
less_equal
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR: attribute out of range!"
);
MIGRAPHX_THROW
(
"TUNE_VECTOR: attribute out of range!"
);
}
}
}
}
else
else
{
{
if
(
!
std
::
equal
(
result
.
begin
(),
result
.
end
(),
min_vals
.
begin
(),
std
::
less
<>
{}))
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
min_vals
.
begin
(),
std
::
less
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR: attribute out of range!"
);
MIGRAPHX_THROW
(
"TUNE_VECTOR: attribute out of range!"
);
}
}
...
@@ -174,7 +175,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
...
@@ -174,7 +175,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
tuned
=
true
;
tuned
=
true
;
}
}
}
}
if
(
!
attrs
.
contains
(
"normalize_axes"
))
if
(
not
attrs
.
contains
(
"normalize_axes"
))
{
{
return
tuned
;
return
tuned
;
}
}
...
...
src/onnx/onnx_parser.cpp
View file @
72011beb
...
@@ -187,7 +187,7 @@ operation onnx_parser::load(const std::string& name, const node_info& info) cons
...
@@ -187,7 +187,7 @@ operation onnx_parser::load(const std::string& name, const node_info& info) cons
void
onnx_parser
::
parse_undefined
(
module
*
mod
,
const
std
::
string
&
name
)
void
onnx_parser
::
parse_undefined
(
module
*
mod
,
const
std
::
string
&
name
)
{
{
if
(
!
contains
(
instructions
,
name
))
if
(
not
contains
(
instructions
,
name
))
{
{
auto
ins
=
mod
->
add_instruction
(
make_op
(
"undefined"
));
auto
ins
=
mod
->
add_instruction
(
make_op
(
"undefined"
));
instructions
[
name
]
=
ins
;
instructions
[
name
]
=
ins
;
...
@@ -267,7 +267,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
...
@@ -267,7 +267,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
{
const
std
::
string
&
name
=
input
.
name
();
const
std
::
string
&
name
=
input
.
name
();
// input not in initializer_data, so it is a real input
// input not in initializer_data, so it is a real input
if
(
!
contains
(
mod_insts
,
name
))
if
(
not
contains
(
mod_insts
,
name
))
{
{
// ONNX specification does not specify how to deal with the
// ONNX specification does not specify how to deal with the
// scenario that a nested subgraph contains a parameter with the
// scenario that a nested subgraph contains a parameter with the
...
@@ -354,7 +354,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
...
@@ -354,7 +354,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
all_output_names
.
begin
(),
all_output_names
.
begin
(),
all_output_names
.
end
(),
all_output_names
.
end
(),
std
::
back_inserter
(
prog_output_names
),
std
::
back_inserter
(
prog_output_names
),
[
&
](
const
auto
&
name
)
{
return
!
(
name
.
empty
()
or
instructions
.
count
(
name
)
==
0
);
});
[
&
](
const
auto
&
name
)
{
return
not
(
name
.
empty
()
or
instructions
.
count
(
name
)
==
0
);
});
std
::
vector
<
instruction_ref
>
output_ins
;
std
::
vector
<
instruction_ref
>
output_ins
;
std
::
transform
(
prog_output_names
.
begin
(),
std
::
transform
(
prog_output_names
.
begin
(),
...
@@ -444,7 +444,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
...
@@ -444,7 +444,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
{
{
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
if
(
!
input_dims
.
empty
())
if
(
not
input_dims
.
empty
())
{
{
return
{
shape_type
,
input_dims
};
return
{
shape_type
,
input_dims
};
}
}
...
@@ -511,7 +511,7 @@ shape::type_t get_type(int dtype)
...
@@ -511,7 +511,7 @@ shape::type_t get_type(int dtype)
bool
is_type_float
(
shape
::
type_t
dtype
)
bool
is_type_float
(
shape
::
type_t
dtype
)
{
{
bool
r
=
false
;
bool
r
=
false
;
if
(
dtype
==
shape
::
float_type
||
dtype
==
shape
::
double_type
||
dtype
==
shape
::
half_type
)
if
(
dtype
==
shape
::
float_type
or
dtype
==
shape
::
double_type
or
dtype
==
shape
::
half_type
)
{
{
r
=
true
;
r
=
true
;
}
}
...
...
src/onnx/padding.cpp
View file @
72011beb
...
@@ -42,7 +42,7 @@ void cal_auto_padding_size(onnx_parser::node_info info,
...
@@ -42,7 +42,7 @@ void cal_auto_padding_size(onnx_parser::node_info info,
size_t
kdims
=
in_lens
.
size
()
-
2
;
size_t
kdims
=
in_lens
.
size
()
-
2
;
assert
(
k_lens
.
size
()
==
kdims
and
dilation
.
size
()
==
kdims
);
assert
(
k_lens
.
size
()
==
kdims
and
dilation
.
size
()
==
kdims
);
if
(
!
contains
(
info
.
attributes
,
"auto_pad"
))
if
(
not
contains
(
info
.
attributes
,
"auto_pad"
))
{
{
return
;
return
;
}
}
...
@@ -124,7 +124,7 @@ void tune_padding_size(const value& v,
...
@@ -124,7 +124,7 @@ void tune_padding_size(const value& v,
}
}
// if padding is symmetric, return directly
// if padding is symmetric, return directly
if
(
!
is_asym_padding
(
padding
))
if
(
not
is_asym_padding
(
padding
))
{
{
return
;
return
;
}
}
...
...
src/onnx/parse_cast.cpp
View file @
72011beb
...
@@ -38,7 +38,7 @@ struct parse_cast : op_parser<parse_cast>
...
@@ -38,7 +38,7 @@ struct parse_cast : op_parser<parse_cast>
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
{
if
(
!
contains
(
info
.
attributes
,
"to"
))
if
(
not
contains
(
info
.
attributes
,
"to"
))
{
{
MIGRAPHX_THROW
(
"PARSE_CAST: missing to type attribute!"
);
MIGRAPHX_THROW
(
"PARSE_CAST: missing to type attribute!"
);
}
}
...
...
src/onnx/parse_constant_fill.cpp
View file @
72011beb
...
@@ -93,7 +93,7 @@ struct parse_constant_fill : op_parser<parse_constant_fill>
...
@@ -93,7 +93,7 @@ struct parse_constant_fill : op_parser<parse_constant_fill>
}
}
else
if
(
input_as_shape
==
0
)
else
if
(
input_as_shape
==
0
)
{
{
if
(
!
contains
(
info
.
attributes
,
"shape"
))
if
(
not
contains
(
info
.
attributes
,
"shape"
))
{
{
MIGRAPHX_THROW
(
"ConstantFill: attribute output shape is needed"
);
MIGRAPHX_THROW
(
"ConstantFill: attribute output shape is needed"
);
}
}
...
...
src/onnx/parse_gemm.cpp
View file @
72011beb
...
@@ -94,7 +94,7 @@ struct parse_gemm : op_parser<parse_gemm>
...
@@ -94,7 +94,7 @@ struct parse_gemm : op_parser<parse_gemm>
out_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
();
out_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
();
auto
l3
=
args
[
2
];
auto
l3
=
args
[
2
];
auto
l3_lens
=
l3
->
get_shape
().
lens
();
auto
l3_lens
=
l3
->
get_shape
().
lens
();
if
(
!
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
l3_lens
.
begin
(),
l3_lens
.
end
()))
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
l3_lens
.
begin
(),
l3_lens
.
end
()))
{
{
l3
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
l3
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
args
[
2
]);
...
...
src/onnx/parse_lpnormalization.cpp
View file @
72011beb
...
@@ -31,7 +31,7 @@ namespace migraphx {
...
@@ -31,7 +31,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
namespace
onnx
{
//
!
Parser for LpNormalization ONNX operator.
// Parser for LpNormalization ONNX operator.
/*!
/*!
Normalizes a tensor by the L1 or L2 norms along a given axis.
Normalizes a tensor by the L1 or L2 norms along a given axis.
Norms that evaluate to 0 are changed to 1 to prevent division by zero.
Norms that evaluate to 0 are changed to 1 to prevent division by zero.
...
...
src/onnx/parse_matmul.cpp
View file @
72011beb
...
@@ -67,7 +67,8 @@ struct parse_matmul : op_parser<parse_matmul>
...
@@ -67,7 +67,8 @@ struct parse_matmul : op_parser<parse_matmul>
instruction_ref
bl0
=
l0
;
instruction_ref
bl0
=
l0
;
instruction_ref
bl1
=
l1
;
instruction_ref
bl1
=
l1
;
if
(
!
std
::
equal
(
l0_lens
.
rbegin
()
+
2
,
l0_lens
.
rend
(),
l1_lens
.
rbegin
()
+
2
,
l1_lens
.
rend
()))
if
(
not
std
::
equal
(
l0_lens
.
rbegin
()
+
2
,
l0_lens
.
rend
(),
l1_lens
.
rbegin
()
+
2
,
l1_lens
.
rend
()))
{
{
auto
l0_it
=
l0_lens
.
begin
()
+
l0_lens
.
size
()
-
2
;
auto
l0_it
=
l0_lens
.
begin
()
+
l0_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
...
...
src/onnx/parse_mod.cpp
View file @
72011beb
...
@@ -40,9 +40,9 @@ struct parse_mod : op_parser<parse_mod>
...
@@ -40,9 +40,9 @@ struct parse_mod : op_parser<parse_mod>
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
args
)
const
{
{
std
::
string
mod
=
"mod"
;
std
::
string
mod
=
"mod"
;
if
(
is_type_float
(
args
[
0
]
->
get_shape
().
type
())
||
is_type_float
(
args
[
1
]
->
get_shape
().
type
()))
if
(
is_type_float
(
args
[
0
]
->
get_shape
().
type
())
or
is_type_float
(
args
[
1
]
->
get_shape
().
type
()))
{
{
if
(
!
contains
(
info
.
attributes
,
"fmod"
))
if
(
not
contains
(
info
.
attributes
,
"fmod"
))
{
{
MIGRAPHX_THROW
(
"Mod operator with float args and fmod=0 invalid"
);
MIGRAPHX_THROW
(
"Mod operator with float args and fmod=0 invalid"
);
}
}
...
...
src/onnx/parse_nonzero.cpp
View file @
72011beb
...
@@ -37,7 +37,7 @@ static std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
...
@@ -37,7 +37,7 @@ static std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
std
::
vector
<
std
::
size_t
>
indices
;
std
::
vector
<
std
::
size_t
>
indices
;
for
(
std
::
size_t
i
=
0
;
i
<
data
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
data
.
size
();
++
i
)
{
{
if
(
!
float_equal
(
data
[
i
],
0
))
if
(
not
float_equal
(
data
[
i
],
0
))
indices
.
push_back
(
i
);
indices
.
push_back
(
i
);
}
}
...
...
src/onnx/parse_pad.cpp
View file @
72011beb
...
@@ -160,7 +160,7 @@ struct parse_pad : op_parser<parse_pad>
...
@@ -160,7 +160,7 @@ struct parse_pad : op_parser<parse_pad>
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
auto
val_ins
=
args
.
at
(
2
);
auto
val_ins
=
args
.
at
(
2
);
if
(
!
val_ins
->
can_eval
())
if
(
not
val_ins
->
can_eval
())
{
{
MIGRAPHX_THROW
(
"PARSE_PAD: input value must be constant"
);
MIGRAPHX_THROW
(
"PARSE_PAD: input value must be constant"
);
}
}
...
...
src/onnx/parse_pooling.cpp
View file @
72011beb
...
@@ -157,7 +157,7 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -157,7 +157,7 @@ struct parse_pooling : op_parser<parse_pooling>
std
::
vector
<
int64_t
>
slice_end
;
std
::
vector
<
int64_t
>
slice_end
;
tune_padding_size
(
values
,
paddings
,
count_include_pad
,
slice_start
);
tune_padding_size
(
values
,
paddings
,
count_include_pad
,
slice_start
);
if
(
!
slice_start
.
empty
())
if
(
not
slice_start
.
empty
())
{
{
// calculate expected output shape
// calculate expected output shape
orig_padding
.
insert
(
orig_padding
.
begin
()
+
kdims
,
2
,
0
);
orig_padding
.
insert
(
orig_padding
.
begin
()
+
kdims
,
2
,
0
);
...
@@ -180,7 +180,7 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -180,7 +180,7 @@ struct parse_pooling : op_parser<parse_pooling>
op
.
from_value
(
values
);
op
.
from_value
(
values
);
auto
l1
=
info
.
add_instruction
(
op
,
l0
);
auto
l1
=
info
.
add_instruction
(
op
,
l0
);
if
(
!
slice_start
.
empty
())
if
(
not
slice_start
.
empty
())
{
{
std
::
vector
<
int64_t
>
axes
(
kdims
);
std
::
vector
<
int64_t
>
axes
(
kdims
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
2
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
2
);
...
...
src/onnx/parse_pow.cpp
View file @
72011beb
...
@@ -46,7 +46,7 @@ auto compute_type(shape::type_t t1, shape::type_t t2)
...
@@ -46,7 +46,7 @@ auto compute_type(shape::type_t t1, shape::type_t t2)
int
it1
=
t1
;
int
it1
=
t1
;
int
it2
=
t2
;
int
it2
=
t2
;
if
(
!
contains
(
op_order
,
it1
)
or
!
contains
(
op_order
,
it2
))
if
(
not
contains
(
op_order
,
it1
)
or
not
contains
(
op_order
,
it2
))
{
{
MIGRAPHX_THROW
(
"PARSE_POW: Input data type not supported!"
);
MIGRAPHX_THROW
(
"PARSE_POW: Input data type not supported!"
);
}
}
...
...
src/onnx/parse_resize.cpp
View file @
72011beb
...
@@ -56,7 +56,7 @@ const auto& get_nearest_op(const std::string& mode)
...
@@ -56,7 +56,7 @@ const auto& get_nearest_op(const std::string& mode)
return
static_cast
<
std
::
size_t
>
(
std
::
ceil
((
val
)));
return
static_cast
<
std
::
size_t
>
(
std
::
ceil
((
val
)));
}}};
}}};
if
(
!
contains
(
nearest_ops
,
mode
))
if
(
not
contains
(
nearest_ops
,
mode
))
{
{
MIGRAPHX_THROW
(
"PARSE_RESIZE: nearest_mode "
+
mode
+
" not supported!"
);
MIGRAPHX_THROW
(
"PARSE_RESIZE: nearest_mode "
+
mode
+
" not supported!"
);
}
}
...
@@ -86,7 +86,7 @@ const auto& get_original_idx_op(const std::string& mode)
...
@@ -86,7 +86,7 @@ const auto& get_original_idx_op(const std::string& mode)
return
(
idx
+
0.5
)
/
scale
;
return
(
idx
+
0.5
)
/
scale
;
}}};
}}};
if
(
!
contains
(
idx_ops
,
mode
))
if
(
not
contains
(
idx_ops
,
mode
))
{
{
MIGRAPHX_THROW
(
"PARSE_RESIZE: coordinate_transformation_mode "
+
mode
+
" not supported!"
);
MIGRAPHX_THROW
(
"PARSE_RESIZE: coordinate_transformation_mode "
+
mode
+
" not supported!"
);
}
}
...
...
src/onnx/parse_reversesequence.cpp
View file @
72011beb
...
@@ -31,7 +31,7 @@ namespace migraphx {
...
@@ -31,7 +31,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
namespace
onnx
{
//
!
Parser for ReverseSequence ONNX operator.
// Parser for ReverseSequence ONNX operator.
/*!
/*!
Reverses the data along the time axis for the batches along the batch axis.
Reverses the data along the time axis for the batches along the batch axis.
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
...
...
src/opt/memory_coloring.cpp
View file @
72011beb
...
@@ -29,7 +29,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -29,7 +29,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void
memory_coloring
::
apply
(
module
&
m
)
const
void
memory_coloring
::
apply
(
module
&
m
)
const
{
{
if
(
!
enabled
(
MIGRAPHX_DISABLE_MEMORY_COLORING
{}))
if
(
not
enabled
(
MIGRAPHX_DISABLE_MEMORY_COLORING
{}))
{
{
memory_coloring_impl
opt
(
&
m
,
allocation_op
,
verify
);
memory_coloring_impl
opt
(
&
m
,
allocation_op
,
verify
);
opt
.
run
();
opt
.
run
();
...
...
src/opt/memory_coloring_impl.cpp
View file @
72011beb
...
@@ -42,7 +42,7 @@ void memory_coloring_impl::run()
...
@@ -42,7 +42,7 @@ void memory_coloring_impl::run()
{
{
MIGRAPHX_DEBUG
(
dump_intervals
());
MIGRAPHX_DEBUG
(
dump_intervals
());
// Coloring
// Coloring
while
(
!
alloc_queue
.
empty
())
while
(
not
alloc_queue
.
empty
())
{
{
interval_ptr
interval
=
alloc_queue
.
top
();
interval_ptr
interval
=
alloc_queue
.
top
();
allocate
(
interval
);
allocate
(
interval
);
...
@@ -96,7 +96,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
...
@@ -96,7 +96,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
}
}
std
::
size_t
offset
=
0
;
std
::
size_t
offset
=
0
;
while
(
!
conflict_queue
.
empty
())
while
(
not
conflict_queue
.
empty
())
{
{
live_range
*
range
=
conflict_queue
.
top
();
live_range
*
range
=
conflict_queue
.
top
();
std
::
size_t
iter_offset
=
range
->
offset
;
std
::
size_t
iter_offset
=
range
->
offset
;
...
@@ -149,7 +149,7 @@ void memory_coloring_impl::build()
...
@@ -149,7 +149,7 @@ void memory_coloring_impl::build()
{
{
def_interval
=
instr2_live
[
p_iter
];
def_interval
=
instr2_live
[
p_iter
];
bool
is_lit
=
is_literal
(
iter
);
bool
is_lit
=
is_literal
(
iter
);
if
(
is_allocate
(
iter
)
||
is_lit
)
if
(
is_allocate
(
iter
)
or
is_lit
)
{
{
live_range
&
range
=
def_interval
->
segment
;
live_range
&
range
=
def_interval
->
segment
;
def_interval
->
result
=
iter
->
get_shape
();
def_interval
->
result
=
iter
->
get_shape
();
...
@@ -157,12 +157,12 @@ void memory_coloring_impl::build()
...
@@ -157,12 +157,12 @@ void memory_coloring_impl::build()
range
.
begin
=
cur_points
;
range
.
begin
=
cur_points
;
def_interval
->
def_point
=
cur_points
;
def_interval
->
def_point
=
cur_points
;
range
.
size
=
(
iter
->
get_shape
()).
bytes
();
range
.
size
=
(
iter
->
get_shape
()).
bytes
();
if
(
!
is_lit
||
unify_literals
)
if
(
not
is_lit
or
unify_literals
)
alloc_queue
.
push
(
def_interval
);
alloc_queue
.
push
(
def_interval
);
live_set
.
erase
(
range
.
vn
);
live_set
.
erase
(
range
.
vn
);
}
}
}
}
else
if
(
!
is_param
(
iter
)
&&
!
is_outline
(
iter
)
&&
!
is_check_context
(
iter
))
else
if
(
not
is_param
(
iter
)
&&
not
is_outline
(
iter
)
&&
not
is_check_context
(
iter
))
{
{
is_dead
=
true
;
is_dead
=
true
;
}
}
...
@@ -179,7 +179,7 @@ void memory_coloring_impl::build()
...
@@ -179,7 +179,7 @@ void memory_coloring_impl::build()
if
(
not
p_mod
->
has_instruction
(
arg
))
if
(
not
p_mod
->
has_instruction
(
arg
))
continue
;
continue
;
if
(
is_param
(
arg
)
||
is_outline
(
arg
))
if
(
is_param
(
arg
)
or
is_outline
(
arg
))
{
{
if
(
is_output_param
(
arg
))
if
(
is_output_param
(
arg
))
is_dead
=
false
;
is_dead
=
false
;
...
@@ -235,7 +235,7 @@ void memory_coloring_impl::rewrite()
...
@@ -235,7 +235,7 @@ void memory_coloring_impl::rewrite()
if
(
interval
->
get_begin
()
==
invalid_offset
)
if
(
interval
->
get_begin
()
==
invalid_offset
)
continue
;
continue
;
if
(
!
unify_literals
&&
interval
->
is_literal
)
if
(
not
unify_literals
&&
interval
->
is_literal
)
continue
;
continue
;
std
::
size_t
offset
=
0
;
std
::
size_t
offset
=
0
;
...
@@ -272,7 +272,7 @@ void memory_coloring_impl::verify()
...
@@ -272,7 +272,7 @@ void memory_coloring_impl::verify()
if
(
segment
.
begin
==
invalid_offset
)
if
(
segment
.
begin
==
invalid_offset
)
{
{
// if(
!
interval.is_live_on_entry)
// if(
not
interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
// MIGRAPHX_THROW("interval is not live on entry");
continue
;
continue
;
}
}
...
@@ -290,7 +290,7 @@ void memory_coloring_impl::verify()
...
@@ -290,7 +290,7 @@ void memory_coloring_impl::verify()
live_range
*
range
=
live_ranges
[
iter
];
live_range
*
range
=
live_ranges
[
iter
];
if
(
range
->
offset
==
invalid_offset
)
if
(
range
->
offset
==
invalid_offset
)
continue
;
continue
;
if
(
!
is_disjoin
(
*
range
,
segment
))
if
(
not
is_disjoin
(
*
range
,
segment
))
MIGRAPHX_THROW
(
"range and segment is not disjoined"
);
MIGRAPHX_THROW
(
"range and segment is not disjoined"
);
}
}
}
}
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment