gaoqiong / MIGraphX · Commit 7f97b8ef (unverified)

Authored Oct 07, 2022 by Ted Themistokleous; committed by GitHub Oct 07, 2022.

Merge branch 'simplify_1_mul_div_ops' into divide_by_zero_check

Parents: 2ba401f0, d1fed367
Changes: 448 changed files in the full commit; this page shows 20 of them (page 1 of 23), with 208 additions and 80 deletions.
src/include/migraphx/value.hpp                    +6   -0
src/insert_pad.cpp                                +6   -0
src/instruction.cpp                               +7   -7
src/make_op.cpp                                   +5   -0
src/module.cpp                                    +18  -11
src/normalize_attributes.cpp                      +6   -5
src/normalize_ops.cpp                             +2   -2
src/onnx/include/migraphx/onnx/onnx_parser.hpp    +2   -0
src/onnx/onnx.cpp                                 +7   -0
src/onnx/onnx_parser.cpp                          +16  -12
src/onnx/padding.cpp                              +2   -2
src/onnx/parse_batchnorm.cpp                      +49  -14
src/onnx/parse_cast.cpp                           +1   -1
src/onnx/parse_constant.cpp                       +1   -1
src/onnx/parse_constant_fill.cpp                  +1   -1
src/onnx/parse_convolution.cpp                    +63  -17
src/onnx/parse_gemm.cpp                           +1   -1
src/onnx/parse_generic_op.cpp                     +1   -2
src/onnx/parse_if.cpp                             +4   -2
src/onnx/parse_instancenorm.cpp                   +10  -2
src/include/migraphx/value.hpp

@@ -184,6 +184,12 @@ struct value
     {
     }
     explicit binary(std::size_t s) : base(s) {}
+
+    friend std::ostream& operator<<(std::ostream& os, const binary& obj)
+    {
+        os << "{binary_object: " << obj.size() << "}";
+        return os;
+    }
 };

 value() = default;
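The new stream operator prints a placeholder with the byte count instead of dumping raw bytes, which keeps program dumps readable when a value carries a large binary blob. A minimal usage sketch (the surrounding main() and the 16-byte size are illustrative, not part of the commit):

    #include <iostream>
    #include <migraphx/value.hpp>

    int main()
    {
        migraphx::value::binary blob(16); // 16 zero-initialized bytes, via the ctor shown above
        std::cout << blob << '\n';        // prints: {binary_object: 16}
    }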
src/insert_pad.cpp

@@ -40,6 +40,12 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
     auto val        = op.to_value();
     auto op_padding = val.at("padding").to_vector<size_t>();
+
+    // skip if shape is dynamic
+    if(input->get_shape().dynamic())
+    {
+        return;
+    }
     auto kdims = input->get_shape().lens().size() - 2;
     if(std::equal(op_padding.begin(),
                   op_padding.begin() + kdims,
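The early return guards the padding arithmetic below it: lens() is only meaningful for a fixed shape, so the pass now leaves dynamically shaped inputs untouched. For static shapes nothing changes; for example, an NCHW input with lens() == {1, 3, 224, 224} gives kdims == 4 - 2 == 2, the two spatial dimensions whose padding entries are compared.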
src/instruction.cpp

@@ -176,13 +176,13 @@ bool operator==(const instruction& x, const instruction& y)
     return true;
 }

-bool operator!=(const instruction& x, const instruction& y) { return !(x == y); }
+bool operator!=(const instruction& x, const instruction& y) { return not(x == y); }

 bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }

-bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
+bool operator!=(const instruction& i, instruction_ref ref) { return not(i == ref); }

-bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
+bool operator!=(instruction_ref ref, const instruction& i) { return not(i == ref); }

 void instruction::add_output(instruction_ref ins)
 {

@@ -361,7 +361,7 @@ void instruction::print(std::ostream& os,
         os << "{" << ins->get_literal() << "}";
     }

-    if(!ins->inputs().empty())
+    if(not ins->inputs().empty())
     {
         char delim = '(';
         for(auto&& arg : ins->inputs())

@@ -374,7 +374,7 @@ void instruction::print(std::ostream& os,
     }

     // print module inputs
-    if(!ins->module_inputs().empty())
+    if(not ins->module_inputs().empty())
     {
         std::string delim = ", [";
         for(auto&& mod_arg : ins->module_inputs())

@@ -445,8 +445,8 @@ operation instruction::normalized_operator() const
     operation o = this->get_operator();
     if(this->need_normalization())
     {
-        auto lens = this->inputs().front()->get_shape().lens();
-        if(!normalize_attributes(o, lens))
+        auto s = this->inputs().front()->get_shape();
+        if(not normalize_attributes(o, s.max_lens()))
             return this->get_operator();
     }
     return o;
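The normalized_operator() change swaps lens() for max_lens() so attribute normalization also works when the input shape is dynamic. For a fixed shape the two calls should agree, so static models are unaffected; a small sketch of that assumed invariant:

    #include <cassert>
    #include <migraphx/shape.hpp>

    int main()
    {
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        assert(s.lens() == s.max_lens()); // identical for a static shape
    }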
src/make_op.cpp

@@ -64,5 +64,10 @@ operation make_op_from_value(const std::string& name, const value& v)
     });
 }

+operation make_json_op(const std::string& name, const std::string& s)
+{
+    return make_op(name, from_json_string(convert_to_json(s)));
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
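make_json_op builds an operator from a single attribute string by chaining convert_to_json (which is assumed here to normalize relaxed, unquoted-key syntax into strict JSON) with from_json_string and make_op. A hedged usage sketch; the slice attributes and the header that declares the function are illustrative assumptions:

    #include <migraphx/make_op.hpp> // assumed to declare make_json_op alongside make_op

    int main()
    {
        // Relaxed attribute syntax; convert_to_json is assumed to quote the keys first.
        auto op = migraphx::make_json_op("slice", "{axes: [0], starts: [0], ends: [2]}");
        // Assumed equivalent to make_op("slice", from_json_string("{\"axes\":[0],...}")).
    }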
src/module.cpp

@@ -141,12 +141,12 @@ void module::set_bypass(bool b) { impl->bypass = b; }
 void module::assign(const module& m)
 {
     // copy the impl
-    if(!impl)
+    if(not impl)
         impl = std::make_unique<module_impl>();
     *impl = *m.impl;

     // clear instructions
-    if(!impl->instructions.empty())
+    if(not impl->instructions.empty())
     {
         impl->clear();
     }

@@ -357,7 +357,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
         assert(out->valid(begin()));
     }
     // Replacement should not be dead code unless its the last instruction
-    assert(!rep->outputs().empty() or rep == std::prev(end()));
+    assert(not rep->outputs().empty() or rep == std::prev(end()));
     // Output of the original instruction should only be the replacement or empty
     assert(ins->outputs().empty() or
            std::all_of(ins->outputs().begin(),
                        ins->outputs().end(),

@@ -396,9 +396,13 @@ instruction_ref module::move_instruction(instruction_ref src, instruction_ref ds
 instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst)
 {
-    this->move_instruction(src, dst);
     for(auto ins : src->inputs())
-        this->move_instruction(ins, src);
+    {
+        if(not contains(this->impl->instructions, ins))
+            continue;
+        this->move_instructions(ins, dst);
+    }
+    this->move_instruction(src, dst);
     return src;
 }

@@ -623,7 +627,7 @@ instruction_ref module::validate() const
         auto inputs      = i.inputs();
         bool check_order = std::all_of(
             inputs.begin(), inputs.end(), [&](auto in) { return has_instruction(in); });
-        return !i.valid(impl->instructions.begin(), check_order);
+        return not i.valid(impl->instructions.begin(), check_order);
     });
 }

@@ -788,7 +792,7 @@ void module::print_graph(std::ostream& os, bool brief) const
             label = to_string(ins->get_operator());
         os << "\t" << enclose_name(ins_names.at(ins)) << "[label=" << enclose_name(label) << "]";
         os << ";" << std::endl;
-        if(!ins->inputs().empty())
+        if(not ins->inputs().empty())
         {
             for(auto&& arg : ins->inputs())
             {

@@ -822,12 +826,15 @@ static std::string cpp_var_name(const std::string& name)
 static void print_make_op(std::ostream& os, const operation& op)
 {
-    os << "migraphx::make_op(" << enclose_name(op.name());
     auto v = op.to_value();
     if(not v.empty())
     {
-        os << ", " << "migraphx::from_json_string(" << enclose_name(to_json_string(v)) << ")";
+        os << "migraphx::make_json_op(" << enclose_name(op.name());
+        os << ", " << enclose_name(to_json_string(v));
+    }
+    else
+    {
+        os << "migraphx::make_op(" << enclose_name(op.name());
     }
     os << ")";
 }

@@ -939,7 +946,7 @@ module& module::sort()
         this->move_instruction(ins, this->begin());
         for(auto child : ins->inputs())
         {
-            if(!contains(this->impl->instructions, child))
+            if(not contains(this->impl->instructions, child))
             {
                 continue;
             }
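Two behavioral fixes hide in the move_instructions hunk: inputs are now moved before src rather than after it, and the move recurses, so the whole dependency chain of src lands in front of dst in topological order; inputs owned by a different module are skipped instead of being dragged across. The print_make_op hunk makes the generated C++ round-trip through the new make_json_op whenever an operator has attributes. Expected instruction order after m.move_instructions(src, dst), as a schematic (names illustrative):

    // ... earlier instructions ...
    // in0, in1, ...   <- transitive inputs of src that live in m, moved first (depth-first)
    // src             <- moved last, immediately before dst
    // dst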
src/normalize_attributes.cpp

@@ -79,14 +79,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     {
         if(contains(vec_attrs, op::normalize_attribute::include_max))
         {
-            if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
+            if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
             {
                 MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
             }
         }
         else
         {
-            if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
+            if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
             {
                 MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
             }

@@ -118,14 +118,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     {
         if(contains(vec_attrs, op::normalize_attribute::include_min))
         {
-            if(!std::equal(min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
+            if(not std::equal(min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
             {
                 MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
             }
         }
         else
         {
-            if(!std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
+            if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
             {
                 MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
             }

@@ -174,7 +175,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
             tuned = true;
         }
     }
-    if(!attrs.contains("normalize_axes"))
+    if(not attrs.contains("normalize_axes"))
     {
         return tuned;
     }
src/normalize_ops.cpp

@@ -43,9 +43,9 @@ void normalize_ops::apply(module& m) const
         if(inputs.empty())
             continue;

-        auto lens = inputs[0]->get_shape().lens();
+        auto s = inputs[0]->get_shape();
         migraphx::operation tuned_op = ins->get_operator();
-        if(normalize_attributes(tuned_op, lens))
+        if(normalize_attributes(tuned_op, s.max_lens()))
         {
             m.replace_instruction(ins, tuned_op, inputs);
             ins->set_normalized();
src/onnx/include/migraphx/onnx/onnx_parser.hpp

@@ -97,6 +97,7 @@ struct onnx_parser
     shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
     std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
+    std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
     bool use_dyn_output         = false;
     bool skip_unknown_operators = false;
     int64_t max_loop_iterations = 10;
     int64_t opset_version       = 13;

@@ -119,6 +120,7 @@ struct onnx_parser
 };

 shape::type_t get_type(int dtype);
+bool is_type_float(shape::type_t dtype);
 } // namespace onnx
 } // namespace MIGRAPHX_INLINE_NS
src/onnx/onnx.cpp

@@ -60,8 +60,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
     {
         parser.default_dyn_dim_value = options.default_dyn_dim_value;
     }
+    if(not options.map_input_dims.empty() and not options.map_dyn_input_dims.empty())
+    {
+        MIGRAPHX_THROW("PARSE_ONNX_FROM: both map_input_dims and map_dyn_input_dims non-empty, only"
+                       "one should be used");
+    }
+
     parser.skip_unknown_operators = options.skip_unknown_operators;
     parser.max_loop_iterations    = options.max_loop_iterations;
     parser.use_dyn_output         = options.use_dyn_output;
     if(options.print_program_on_error)
     {

@@ -80,6 +86,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
     {
         parser.parse_from(std::forward<Ts>(xs)...);
     }
+
     return std::move(parser.prog);
 }
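The new check makes fixed and dynamic input-dimension overrides mutually exclusive at the API boundary; it replaces the equivalent check removed from parse_graph (see onnx_parser.cpp below). A hedged sketch of the now-rejected configuration, using the option fields this diff references; the {min, max, opt} triples follow the default_dyn_dim_value = {1, 1, 0} form shown in the header diff, and the exact public field names on onnx_options are assumed to mirror the parser's:

    #include <migraphx/onnx.hpp>

    int main()
    {
        migraphx::onnx_options options;
        options.map_input_dims["x"]     = {1, 3, 224, 224}; // fixed dims
        options.map_dyn_input_dims["x"] = {{1, 4, 0}, {3, 3, 0}, {224, 224, 0}, {224, 224, 0}};
        auto prog = migraphx::parse_onnx("model.onnx", options); // would now throw: both maps non-empty
    }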
src/onnx/onnx_parser.cpp

@@ -28,7 +28,6 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/instruction.hpp>
-#include <migraphx/pad_calc.hpp>
 #include <migraphx/common.hpp>
 #include <migraphx/type_traits.hpp>
 #include <migraphx/float_equal.hpp>

@@ -60,7 +59,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const
         std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
     if(elem_num == 0)
     {
-        return {};
+        return literal{shape_type};
     }

     // in case of scalar constants in onnx file, use dims=1 to fill initializer data

@@ -77,7 +76,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t
         std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
     if(elem_num == 0)
     {
-        return {};
+        return literal{shape_type};
     }

     // scalar input

@@ -188,7 +187,7 @@ operation onnx_parser::load(const std::string& name, const node_info& info) cons
 void onnx_parser::parse_undefined(module* mod, const std::string& name)
 {
-    if(!contains(instructions, name))
+    if(not contains(instructions, name))
     {
         auto ins           = mod->add_instruction(make_op("undefined"));
         instructions[name] = ins;

@@ -257,11 +256,6 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
 void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
 {
-    if(not map_input_dims.empty() and not map_dyn_input_dims.empty())
-    {
-        MIGRAPHX_THROW("PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only"
-                       "one should be used");
-    }
     std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
     {

@@ -273,7 +267,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
         const std::string& name = input.name();
         // input not in initializer_data, so it is a real input
-        if(!contains(mod_insts, name))
+        if(not contains(mod_insts, name))
         {
             // ONNX specification does not specify how to deal with the
             // scenario that a nested subgraph contains a parameter with the

@@ -360,7 +354,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
                  all_output_names.begin(),
                  all_output_names.end(),
                  std::back_inserter(prog_output_names),
-                 [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });
+                 [&](const auto& name) { return not(name.empty() or instructions.count(name) == 0); });

     std::vector<instruction_ref> output_ins;
     std::transform(prog_output_names.begin(),

@@ -450,7 +444,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
                               const std::vector<std::size_t>& input_dims) const
 {
     shape::type_t shape_type = get_type(t.tensor_type().elem_type());
-    if(!input_dims.empty())
+    if(not input_dims.empty())
     {
         return {shape_type, input_dims};
     }

@@ -514,6 +508,16 @@ shape::type_t get_type(int dtype)
     }
 }

+bool is_type_float(shape::type_t dtype)
+{
+    bool r = false;
+    if(dtype == shape::float_type or dtype == shape::double_type or dtype == shape::half_type)
+    {
+        r = true;
+    }
+    return r;
+}
+
 } // namespace onnx
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
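Returning literal{shape_type} instead of a default-constructed literal preserves the element type of zero-element initializers, and the duplicate-map check now lives in parse_onnx_from (see onnx.cpp above). The new is_type_float helper is a plain whitelist over the three floating-point element types; a sketch grounded directly in its body (assumes the internal header is reachable, e.g. in-tree):

    #include <cassert>
    #include <migraphx/onnx/onnx_parser.hpp>

    int main()
    {
        using migraphx::shape;
        assert(migraphx::onnx::is_type_float(shape::float_type));
        assert(migraphx::onnx::is_type_float(shape::half_type));
        assert(not migraphx::onnx::is_type_float(shape::int32_type)); // everything else is false
    }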
src/onnx/padding.cpp

@@ -42,7 +42,7 @@ void cal_auto_padding_size(onnx_parser::node_info info,
     size_t kdims = in_lens.size() - 2;
     assert(k_lens.size() == kdims and dilation.size() == kdims);

-    if(!contains(info.attributes, "auto_pad"))
+    if(not contains(info.attributes, "auto_pad"))
     {
         return;
     }

@@ -124,7 +124,7 @@ void tune_padding_size(const value& v,
     }

     // if padding is symmetric, return directly
-    if(!is_asym_padding(padding))
+    if(not is_asym_padding(padding))
     {
         return;
     }
src/onnx/parse_batchnorm.cpp

@@ -24,7 +24,7 @@
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/make_op.hpp>
-#include <migraphx/op/batch_norm_inference.hpp>
+#include <migraphx/instruction.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -36,28 +36,63 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
     instruction_ref parse(const op_desc& /*opd*/,
                           const onnx_parser& parser,
-                          onnx_parser::node_info info,
-                          const std::vector<instruction_ref>& args) const
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
     {
-        float epsilon  = 1e-5f;
-        float momentum = 0.9f;
-        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
+        float epsilon = 1e-5f;
         if(contains(info.attributes, "epsilon"))
         {
             epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
         }
-        if(contains(info.attributes, "momentum"))
-        {
-            momentum = parser.parse_value(info.attributes.at("momentum")).at<float>();
-        }
-        if(contains(info.attributes, "spatial"))
-        {
-            bn_mode = (parser.parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
-                          ? op::batch_norm_inference::spatial
-                          : op::batch_norm_inference::per_activation;
-        }
-        op::batch_norm_inference op{epsilon, momentum, bn_mode};
-        return info.add_instruction(op, args);
+        auto x_lens = args[0]->get_shape().lens();
+        auto x_type = args[0]->get_shape().type();
+        if(std::any_of(args.cbegin() + 1, args.cend(), [](auto a) {
+               return a->get_shape().lens().size() != 1;
+           }))
+        {
+            MIGRAPHX_THROW("PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1");
+        }
+        if(x_lens.size() == 1)
+        {
+            auto rt   = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
+            auto eps  = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto n0   = info.add_broadcastable_binary_op("sub", args[0], args[3]);
+            auto d0   = info.add_broadcastable_binary_op("add", args[4], eps);
+            auto d1   = info.add_broadcastable_binary_op("pow", d0, rt);
+            auto div0 = info.add_broadcastable_binary_op("div", n0, d1);
+            auto r0   = info.add_broadcastable_binary_op("mul", div0, args[1]);
+            return info.add_broadcastable_binary_op("add", r0, args[2]);
+        }
+        else if(x_lens.size() > 2)
+        {
+            // unsqueeze tensors of shape (C) to broadcast correctly
+            std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
+            std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
+            auto rt  = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
+            auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto scale_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
+            auto bias_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[2]);
+            auto mean_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
+            auto var_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
+            auto numer   = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
+            auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
+            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
+            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
+            auto r0      = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
+            return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
+        }
+        else
+        {
+            // num dims either 0 or 2
+            MIGRAPHX_THROW("PARSE_BATCHNORM: rank " + std::to_string(x_lens.size()) +
+                           " input tensor, unhandled data format");
+        }
     }
 };
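Both new branches expand BatchNormalization into primitive elementwise ops instead of the removed op::batch_norm_inference, computing the standard inference-time formula with the square root written as pow(·, 0.5):

    y = scale * (x - mean) / (variance + epsilon)^0.5 + bias

For inputs of rank 3 and above, the 1-D per-channel tensors (scale, bias, mean, variance) are first unsqueezed to shape (C, 1, ..., 1) so add_broadcastable_binary_op can broadcast them across the spatial dimensions. The momentum and spatial attributes are no longer read; they only affect training-mode semantics, which this inference parser does not model.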
src/onnx/parse_cast.cpp

@@ -38,7 +38,7 @@ struct parse_cast : op_parser<parse_cast>
                           onnx_parser::node_info info,
                           const std::vector<instruction_ref>& args) const
     {
-        if(!contains(info.attributes, "to"))
+        if(not contains(info.attributes, "to"))
         {
             MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
         }
src/onnx/parse_constant.cpp

@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant>
         // return empty literal
         if(v.get_shape().elements() == 0)
         {
-            return info.add_literal(literal{});
+            return info.add_literal(literal{v.get_shape().type()});
         }

         auto dim_size = info.attributes.at("value").t().dims_size();
src/onnx/parse_constant_fill.cpp

@@ -93,7 +93,7 @@ struct parse_constant_fill : op_parser<parse_constant_fill>
         }
         else if(input_as_shape == 0)
         {
-            if(!contains(info.attributes, "shape"))
+            if(not contains(info.attributes, "shape"))
             {
                 MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
             }
src/onnx/parse_convolution.cpp

@@ -47,15 +47,17 @@ struct parse_convolution : op_parser<parse_convolution>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
-        auto op      = make_op(opd.op_name);
-        auto values  = op.to_value();
-        auto l0      = args[0];
-        auto weights = args[1];
-        auto in_lens = l0->get_shape().lens();
+        auto op       = make_op(opd.op_name);
+        auto values   = op.to_value();
+        auto l0       = args[0];
+        auto weights  = args[1];
+        auto l0_shape = l0->get_shape();
+        auto w_shape  = weights->get_shape();
+        auto in_lens  = l0_shape.max_lens();
         assert(in_lens.size() > 2);
         auto kdims = in_lens.size() - 2;

-        // ensure pads availabe only when auto_pad is "NOT_SET"
+        // ensure pads available only when auto_pad is "NOT_SET"
         check_padding_mode(info, "CONV");

         if(contains(info.attributes, "strides"))

@@ -79,21 +81,65 @@ struct parse_convolution : op_parser<parse_convolution>
             copy(info.attributes["pads"].ints(), std::back_inserter(padding));
             check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
         }

         if(contains(info.attributes, "auto_pad"))
         {
-            auto weight_lens = weights->get_shape().lens();
-            std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
-            cal_auto_padding_size(info,
-                                  values,
-                                  k_lens,
-                                  values["dilation"].to_vector<std::size_t>(),
-                                  in_lens,
-                                  padding);
-            auto auto_pad = info.attributes["auto_pad"].s();
+            bool is_same_padding = false;
+            auto auto_pad        = info.attributes["auto_pad"].s();
             if(auto_pad.find("SAME") != std::string::npos)
             {
-                values["padding_mode"] = to_value(op::padding_mode_t::same);
+                is_same_padding = true;
             }
+
+            // check if image shape is dynamic
+            bool image_shape_dynamic = false;
+            if(l0_shape.dynamic())
+            {
+                auto dyn_dims = l0_shape.dyn_dims();
+                std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
+                    if(not dyn_dim.is_fixed())
+                    {
+                        image_shape_dynamic = true;
+                    }
+                });
+            }
+
+            // check if kernel shape is dynamic
+            bool kernel_shape_dynamic = false;
+            if(w_shape.dynamic())
+            {
+                auto dyn_dims = w_shape.dyn_dims();
+                std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
+                    if(not dyn_dim.is_fixed())
+                    {
+                        kernel_shape_dynamic = true;
+                    }
+                });
+            }
+
+            if(is_same_padding)
+            {
+                if(image_shape_dynamic or kernel_shape_dynamic)
+                {
+                    // must calculate "same" padding with input shape data
+                    bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
+                    values["padding_mode"] = is_same_upper
+                                                 ? to_value(op::padding_mode_t::same_upper)
+                                                 : to_value(op::padding_mode_t::same_lower);
+                    values["use_dynamic_same_auto_pad"] = true;
+                }
+                else
+                {
+                    values["padding_mode"] = to_value(op::padding_mode_t::same);
+                    // kernel shape will be fixed, so max_lens() == min_len() for kernel lengths
+                    auto weight_lens = weights->get_shape().max_lens();
+                    std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
+                    cal_auto_padding_size(info,
+                                          values,
+                                          k_lens,
+                                          values["dilation"].to_vector<std::size_t>(),
+                                          in_lens,
+                                          padding);
+                }
+            }
         }

         values["padding"] = std::vector<size_t>(padding.begin(), padding.end());
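The auto_pad handling now distinguishes static from dynamic spatial dimensions before committing to a padding strategy. A pseudocode condensation of the branch above:

    // if auto_pad contains "SAME":
    //     if any image or kernel spatial dim is dynamic (not is_fixed()):
    //         padding_mode = same_upper or same_lower   // pads resolved at run time
    //         use_dynamic_same_auto_pad = true
    //     else:
    //         padding_mode = same
    //         cal_auto_padding_size(...)                // concrete pads computed now

Only dimensions past the first two (batch, channels) are scanned in dyn_dims(), since SAME padding concerns the spatial axes alone.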
src/onnx/parse_gemm.cpp

@@ -94,7 +94,7 @@ struct parse_gemm : op_parser<parse_gemm>
                 out_lens.back() = l2->get_shape().lens().back();
                 auto l3         = args[2];
                 auto l3_lens    = l3->get_shape().lens();
-                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
+                if(not std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                 {
                     l3 = info.add_instruction(
                         make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
src/onnx/parse_generic_op.cpp

@@ -58,7 +58,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
                 {"Log", "log"},
                 {"LRN", "lrn"},
                 {"Neg", "neg"},
-                {"NonMaxSuppression", "nonmaxsuppression"},
                 {"Reciprocal", "recip"},
                 {"Relu", "relu"},
                 {"Round", "round"},

@@ -75,7 +74,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
     bool needs_contiguous(const std::string& op_name) const
     {
-        return contains({"flatten", "gather", "nonmaxsuppression", "scatter"}, op_name);
+        return contains({"flatten", "gather", "scatter"}, op_name);
     }

     instruction_ref parse(const op_desc& opd,
src/onnx/parse_if.cpp

@@ -47,7 +47,8 @@ struct parse_if : op_parser<parse_if>
         if(args.front()->get_shape().elements() != 1)
         {
-            MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
+            MIGRAPHX_THROW("PARSE_IF: " + info.name +
+                           " condition input can have only one element!");
         }

         std::string then_name = info.name + "_if";

@@ -69,7 +70,8 @@ struct parse_if : op_parser<parse_if>
                       else_out_shapes.begin(),
                       else_out_shapes.end()))
         {
-            MIGRAPHX_THROW("PARSE_IF: then and else sub_grahps must have same output shapes!");
+            MIGRAPHX_THROW("PARSE_IF: " + info.name +
+                           " then and else sub_grahps must have same output shapes!");
         }

         auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
src/onnx/parse_instancenorm.cpp

@@ -32,9 +32,12 @@ namespace onnx {
 struct parse_instancenorm : op_parser<parse_instancenorm>
 {
+    const std::set<shape::type_t> valid_types = {
+        shape::float_type, shape::half_type, shape::double_type};
+
     std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; }

-    instruction_ref parse(const op_desc& /*opd*/,
+    instruction_ref parse(const op_desc& opd,
                           const onnx_parser& parser,
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const

@@ -52,6 +55,11 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
         auto scale = args[1];
         auto bias  = args[2];
         auto dims  = x->get_shape().lens();
+        auto dtype = x->get_shape().type();
+        if(not contains(valid_types, dtype))
+            MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
+                           ". Valid types are 1 (float), 10 (half), and 11 (double).");
+
         auto ndims = dims.size();
         assert(ndims >= 2);
         auto kdims = ndims - 2;

@@ -65,7 +73,7 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
         auto l0       = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
         auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
         auto l1       = info.add_instruction(make_op("sub"), x, mean_bcast);
-        auto epsilon_literal = info.add_literal(epsilon);
+        auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}});
         auto epsilon_bcast =
             info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
         auto variance_bcast =
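Two small hardening changes here: the parser now rejects element types outside float/half/double up front, and the epsilon constant is created as literal{shape{dtype}, {epsilon}} so it carries the input's element type instead of always being a float scalar; without that, a half-precision graph would mix float and half operands at the add. A sketch of the literal's effect, assuming the constructor semantics implied by its use in the diff:

    auto dtype   = migraphx::shape::half_type; // e.g. from x->get_shape().type()
    float eps    = 1e-5f;
    auto eps_lit = migraphx::literal{migraphx::shape{dtype}, {eps}};
    // eps_lit.get_shape().type() == migraphx::shape::half_type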