Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4a39a0f7
Commit
4a39a0f7
authored
Oct 11, 2021
by
Shucai Xiao
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test
parents
5564172e
bb827865
Changes
542
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
397 additions
and
192 deletions
+397
-192
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+4
-0
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+24
-61
src/onnx/padding.cpp
src/onnx/padding.cpp
+14
-5
src/onnx/parse_binary_op.cpp
src/onnx/parse_binary_op.cpp
+2
-1
src/onnx/parse_clip.cpp
src/onnx/parse_clip.cpp
+2
-2
src/onnx/parse_convolution.cpp
src/onnx/parse_convolution.cpp
+1
-1
src/onnx/parse_depthtospace.cpp
src/onnx/parse_depthtospace.cpp
+74
-0
src/onnx/parse_dequantizelinear.cpp
src/onnx/parse_dequantizelinear.cpp
+28
-25
src/onnx/parse_expand.cpp
src/onnx/parse_expand.cpp
+2
-2
src/onnx/parse_gather_elements.cpp
src/onnx/parse_gather_elements.cpp
+2
-2
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+34
-10
src/onnx/parse_generic_op.cpp
src/onnx/parse_generic_op.cpp
+3
-1
src/onnx/parse_if.cpp
src/onnx/parse_if.cpp
+28
-45
src/onnx/parse_imagescalar.cpp
src/onnx/parse_imagescalar.cpp
+1
-1
src/onnx/parse_instancenorm.cpp
src/onnx/parse_instancenorm.cpp
+6
-6
src/onnx/parse_loop.cpp
src/onnx/parse_loop.cpp
+72
-0
src/onnx/parse_matmul.cpp
src/onnx/parse_matmul.cpp
+5
-6
src/onnx/parse_multinomial.cpp
src/onnx/parse_multinomial.cpp
+65
-0
src/onnx/parse_nonzero.cpp
src/onnx/parse_nonzero.cpp
+25
-20
src/onnx/parse_onehot.cpp
src/onnx/parse_onehot.cpp
+5
-4
No files found.
src/onnx/onnx.cpp
View file @
4a39a0f7
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <iostream>
#include <iostream>
#include <fstream>
#include <fstream>
#include <unordered_map>
#include <unordered_map>
...
@@ -20,6 +21,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
...
@@ -20,6 +21,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
parser
.
map_input_dims
=
options
.
map_input_dims
;
parser
.
map_input_dims
=
options
.
map_input_dims
;
parser
.
default_dim_value
=
options
.
default_dim_value
;
parser
.
default_dim_value
=
options
.
default_dim_value
;
parser
.
skip_unknown_operators
=
options
.
skip_unknown_operators
;
parser
.
skip_unknown_operators
=
options
.
skip_unknown_operators
;
parser
.
max_loop_iterations
=
options
.
max_loop_iterations
;
if
(
options
.
print_program_on_error
)
if
(
options
.
print_program_on_error
)
{
{
...
@@ -57,5 +59,7 @@ program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options
...
@@ -57,5 +59,7 @@ program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options
return
parse_onnx_from
(
options
,
data
,
size
);
return
parse_onnx_from
(
options
,
data
,
size
);
}
}
std
::
vector
<
std
::
string
>
get_onnx_operators
()
{
return
onnx
::
get_op_parsers
();
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/onnx/onnx_parser.cpp
View file @
4a39a0f7
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/common.hpp>
#include <migraphx/type_traits.hpp>
#include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/file_buffer.hpp>
...
@@ -84,73 +85,18 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
...
@@ -84,73 +85,18 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
auto
bias_bcast
=
mod
->
add_instruction
(
auto
bias_bcast
=
mod
->
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"
dim
s"
,
curr_ins
->
get_shape
().
lens
()}}),
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"
out_len
s"
,
curr_ins
->
get_shape
().
lens
()}}),
args
[
2
]);
args
[
2
]);
return
mod
->
add_instruction
(
make_op
(
"add"
),
curr_ins
,
bias_bcast
);
return
mod
->
add_instruction
(
make_op
(
"add"
),
curr_ins
,
bias_bcast
);
}
}
return
curr_ins
;
return
curr_ins
;
}
}
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
if
(
s0
.
size
()
>
s1
.
size
())
{
s0
.
swap
(
s1
);
}
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
to_string_range
(
s0
)
+
"} and {"
+
to_string_range
(
s1
)
+
"} mismatch!"
);
}
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
}
instruction_ref
onnx_parser
::
node_info
::
add_broadcastable_binary_op
(
const
std
::
string
&
op_name
,
instruction_ref
onnx_parser
::
node_info
::
add_broadcastable_binary_op
(
const
std
::
string
&
op_name
,
instruction_ref
arg0
,
instruction_ref
arg0
,
instruction_ref
arg1
)
const
instruction_ref
arg1
)
const
{
{
if
(
arg0
->
get_shape
().
lens
()
!=
arg1
->
get_shape
().
lens
())
return
add_common_op
(
*
mod
,
make_op
(
op_name
),
{
arg0
,
arg1
});
{
// Get lengths for both arguments
auto
s0
=
arg0
->
get_shape
().
lens
();
auto
s1
=
arg1
->
get_shape
().
lens
();
auto
out_lens
=
compute_broadcasted_lens
(
s0
,
s1
);
auto
l0
=
arg0
;
if
(
arg0
->
get_shape
().
lens
()
!=
out_lens
)
l0
=
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"output_lens"
,
out_lens
}}),
arg0
);
auto
l1
=
arg1
;
if
(
arg1
->
get_shape
().
lens
()
!=
out_lens
)
l1
=
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"output_lens"
,
out_lens
}}),
arg1
);
return
add_instruction
(
make_op
(
op_name
),
l0
,
l1
);
}
else
{
return
add_instruction
(
make_op
(
op_name
),
{
arg0
,
arg1
});
}
}
}
instruction_ref
instruction_ref
...
@@ -278,28 +224,42 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
...
@@ -278,28 +224,42 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
)
void
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
)
{
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
;
for
(
auto
&&
f
:
graph
.
initializer
())
for
(
auto
&&
f
:
graph
.
initializer
())
{
{
instructions
[
f
.
name
()]
=
mod
->
add_literal
(
parse_tensor
(
f
));
// backup instructions in parent mod
mod_insts
[
f
.
name
()]
=
mod
->
add_literal
(
parse_tensor
(
f
));
}
}
for
(
auto
&&
input
:
graph
.
input
())
for
(
auto
&&
input
:
graph
.
input
())
{
{
const
std
::
string
&
name
=
input
.
name
();
const
std
::
string
&
name
=
input
.
name
();
// input not in initializer_data, so it is a real input
// input not in initializer_data, so it is a real input
if
(
!
contains
(
inst
ruction
s
,
name
))
if
(
!
contains
(
mod_
insts
,
name
))
{
{
// ONNX specification does not specify hwo to deal with the
// scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that.
if
(
contains
(
instructions
,
name
))
{
MIGRAPHX_THROW
(
"module
\"
"
+
mod
->
name
()
+
"
\"
has parameter name
\"
"
+
name
+
"
\"
existing in parent graph!"
);
}
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
if
(
map_input_dims
.
count
(
name
)
>
0
)
if
(
map_input_dims
.
count
(
name
)
>
0
)
{
{
dims
=
map_input_dims
.
at
(
name
);
dims
=
map_input_dims
.
at
(
name
);
}
}
shape
s
=
parse_type
(
input
.
type
(),
dims
);
shape
s
=
parse_type
(
input
.
type
(),
dims
);
inst
ruction
s
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
mod_
insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
}
}
}
}
std
::
copy
(
mod_insts
.
begin
(),
mod_insts
.
end
(),
std
::
inserter
(
instructions
,
instructions
.
end
()));
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
{
{
std
::
vector
<
instruction_ref
>
args
;
std
::
vector
<
instruction_ref
>
args
;
...
@@ -363,6 +323,9 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
...
@@ -363,6 +323,9 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
// add the return instuction
// add the return instuction
mod
->
add_return
(
output_ins
);
mod
->
add_return
(
output_ins
);
// remove instructions added in this mod
erase_if
(
instructions
,
[
&
](
auto
&&
p
)
{
return
mod
->
has_instruction
(
p
.
second
);
});
}
}
literal
onnx_parser
::
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
literal
onnx_parser
::
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
...
...
src/onnx/padding.cpp
View file @
4a39a0f7
...
@@ -126,7 +126,7 @@ void check_asym_padding(const onnx_parser::node_info& info,
...
@@ -126,7 +126,7 @@ void check_asym_padding(const onnx_parser::node_info& info,
auto
left_pad_it
=
padding
.
begin
();
auto
left_pad_it
=
padding
.
begin
();
auto
right_pad_it
=
left_pad_it
+
pad_ndims
;
auto
right_pad_it
=
left_pad_it
+
pad_ndims
;
if
(
is_asym_padding
(
padding
)
or
count_include_pad
==
1
)
if
(
count_include_pad
==
1
)
{
{
std
::
vector
<
int64_t
>
asym_pads
{
0
,
0
,
0
,
0
};
// don't pad N and C
std
::
vector
<
int64_t
>
asym_pads
{
0
,
0
,
0
,
0
};
// don't pad N and C
// add left pads
// add left pads
...
@@ -134,10 +134,19 @@ void check_asym_padding(const onnx_parser::node_info& info,
...
@@ -134,10 +134,19 @@ void check_asym_padding(const onnx_parser::node_info& info,
// add right pads
// add right pads
asym_pads
.
insert
(
asym_pads
.
begin
()
+
pad_ndims
+
4
,
right_pad_it
,
padding
.
end
());
asym_pads
.
insert
(
asym_pads
.
begin
()
+
pad_ndims
+
4
,
right_pad_it
,
padding
.
end
());
ins
=
info
.
add_instruction
(
make_op
(
"pad"
,
{{
"pads"
,
asym_pads
},
{
"value"
,
pad_val
}}),
ins
);
ins
=
info
.
add_instruction
(
make_op
(
"pad"
,
{{
"pads"
,
asym_pads
},
{
"value"
,
pad_val
}}),
ins
);
}
std
::
vector
<
size_t
>
new_padding
(
padding
.
size
());
else
// subtract asym padding originally found from parsing the operator
{
std
::
transform
(
padding
.
begin
(),
v
[
"padding"
]
=
std
::
vector
<
size_t
>
(
left_pad_it
,
right_pad_it
);
left_pad_it
,
asym_pads
.
begin
()
+
2
,
new_padding
.
begin
(),
std
::
minus
<
size_t
>
());
std
::
transform
(
right_pad_it
,
padding
.
end
(),
asym_pads
.
begin
()
+
pad_ndims
+
4
,
new_padding
.
begin
()
+
pad_ndims
,
std
::
minus
<
size_t
>
());
v
[
"padding"
]
=
new_padding
;
}
}
}
}
...
...
src/onnx/parse_binary_op.cpp
View file @
4a39a0f7
...
@@ -36,7 +36,8 @@ struct parse_binary_op : op_parser<parse_binary_op>
...
@@ -36,7 +36,8 @@ struct parse_binary_op : op_parser<parse_binary_op>
{
{
uint64_t
axis
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
();
uint64_t
axis
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
();
auto
l
=
info
.
add_instruction
(
auto
l
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"dims"
,
args
[
0
]
->
get_shape
().
lens
()}}),
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
args
[
0
]
->
get_shape
().
lens
()}}),
args
[
1
]);
args
[
1
]);
return
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
args
[
0
],
l
);
return
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
args
[
0
],
l
);
}
}
...
...
src/onnx/parse_clip.cpp
View file @
4a39a0f7
...
@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip>
...
@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip>
if
(
min_used
)
if
(
min_used
)
{
{
min_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
input_lens
}}),
min_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
min_arg
);
min_arg
);
}
}
if
(
max_used
)
if
(
max_used
)
{
{
max_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
input_lens
}}),
max_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
max_arg
);
max_arg
);
}
}
...
...
src/onnx/parse_convolution.cpp
View file @
4a39a0f7
...
@@ -73,7 +73,7 @@ struct parse_convolution : op_parser<parse_convolution>
...
@@ -73,7 +73,7 @@ struct parse_convolution : op_parser<parse_convolution>
values
[
"padding_mode"
]
=
to_value
(
op
::
padding_mode_t
::
same
);
values
[
"padding_mode"
]
=
to_value
(
op
::
padding_mode_t
::
same
);
}
}
}
}
check_asym_padding
(
info
,
l0
,
padding
,
values
);
values
[
"padding"
]
=
std
::
vector
<
size_t
>
(
padding
.
begin
()
,
padding
.
end
()
);
if
(
contains
(
info
.
attributes
,
"group"
))
if
(
contains
(
info
.
attributes
,
"group"
))
{
{
...
...
src/onnx/parse_depthtospace.cpp
0 → 100644
View file @
4a39a0f7
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_depthtospace
:
op_parser
<
parse_depthtospace
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"DepthToSpace"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
s
=
args
[
0
]
->
get_shape
();
// mode attribute of DepthToSpace
auto
mode
=
std
::
string
(
"DCR"
);
if
(
contains
(
info
.
attributes
,
"mode"
))
{
mode
=
info
.
attributes
.
at
(
"mode"
).
s
();
// DCR or CRD?
}
// blocksize attribute of DepthToSpace
int
blocksize
=
0
;
if
(
contains
(
info
.
attributes
,
"blocksize"
))
{
blocksize
=
info
.
attributes
.
at
(
"blocksize"
).
i
();
}
if
(
blocksize
<
1
)
{
MIGRAPHX_THROW
(
"DepthToSpace: blocksize is less than 1"
);
}
// calculate dimensions
auto
lens1
=
s
.
lens
();
auto
lens2
=
s
.
lens
();
unsigned
long
divisor
=
std
::
pow
(
blocksize
,
2
);
if
((
lens2
[
1
]
%
divisor
)
==
0
)
lens2
[
1
]
=
lens2
[
1
]
/
divisor
;
else
MIGRAPHX_THROW
(
"DepthToSpace: div by blocksize quotient not int "
);
lens1
.
push_back
(
lens1
[
2
]);
lens1
.
push_back
(
lens1
[
3
]);
lens2
[
2
]
=
lens2
[
2
]
*
blocksize
;
lens2
[
3
]
=
lens2
[
3
]
*
blocksize
;
lens1
[
2
]
=
blocksize
;
std
::
vector
<
int64_t
>
perm
;
if
(
mode
==
"DCR"
)
{
lens1
[
3
]
=
lens1
[
1
]
/
divisor
;
lens1
[
1
]
=
blocksize
;
perm
=
{
0
,
3
,
4
,
1
,
5
,
2
};
}
else
if
(
mode
==
"CRD"
)
{
lens1
[
1
]
=
lens1
[
1
]
/
divisor
;
lens1
[
3
]
=
blocksize
;
perm
=
{
0
,
1
,
4
,
2
,
5
,
3
};
}
else
MIGRAPHX_THROW
(
"DepthToSpace: mode attribute cannot be read."
);
auto
temp1
=
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
lens1
}}),
args
[
0
]);
auto
temp2
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
temp1
);
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
lens2
}}),
info
.
make_contiguous
(
temp2
));
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_dequantizelinear.cpp
View file @
4a39a0f7
...
@@ -15,46 +15,49 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
...
@@ -15,46 +15,49 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
instruction_ref
parse
(
const
op_desc
&
opd
,
instruction_ref
parse
(
const
op_desc
&
opd
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
const
std
::
vector
<
instruction_ref
>
&
args
)
const
{
{
int
axis
=
1
;
int
axis
=
1
;
if
(
contains
(
info
.
attributes
,
"axis"
))
if
(
contains
(
info
.
attributes
,
"axis"
))
axis
=
info
.
attributes
.
at
(
"axis"
).
i
();
axis
=
info
.
attributes
.
at
(
"axis"
).
i
();
auto
input_lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
input_lens
=
args
[
0
]
->
get_shape
().
lens
();
int
n_dim
=
static_cast
<
int
>
(
input_lens
.
size
()
)
;
auto
n_dim
=
input_lens
.
size
();
auto
sub_zero_point
=
args
[
0
];
instruction_ref
x_scale
;
if
(
args
[
1
]
->
get_shape
().
elements
()
!=
1
)
{
auto
tuned_axis
=
tune_axis
(
n_dim
,
axis
,
opd
.
op_name
);
x_scale
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
tuned_axis
},
{
"out_lens"
,
input_lens
}}),
args
[
1
]);
}
else
{
x_scale
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
args
[
1
]);
}
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
auto
zero_point
=
args
[
2
];
auto
x_
zero_point
=
args
[
2
];
if
(
not
(
zero_point
->
get_shape
().
elements
()
=
=
1
)
)
if
(
x_
zero_point
->
get_shape
().
elements
()
!
=
1
)
{
{
axis
=
tune_axis
(
n_dim
,
axis
,
opd
.
op_name
);
auto
tuned_axis
=
tune_axis
(
n_dim
,
axis
,
opd
.
op_name
);
zero_point
=
info
.
add_instruction
(
x_zero_point
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"dims"
,
input_lens
}}),
zero_point
);
make_op
(
"broadcast"
,
{{
"axis"
,
tuned_axis
},
{
"out_lens"
,
input_lens
}}),
x_zero_point
);
}
else
{
x_zero_point
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
x_zero_point
);
}
}
auto
zero_point_int32
=
info
.
add_instruction
(
return
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
int32_type
}}),
zero_point
);
make_op
(
"dequantizelinear"
),
args
[
0
],
x_scale
,
x_zero_point
);
auto
sub_zero_point_int32
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
int32_type
}}),
sub_zero_point
);
sub_zero_point
=
info
.
add_broadcastable_binary_op
(
"sub"
,
sub_zero_point_int32
,
zero_point_int32
);
}
}
auto
dequant_input
=
info
.
add_instruction
(
return
info
.
add_instruction
(
make_op
(
"dequantizelinear"
),
args
[
0
],
x_scale
);
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
float_type
}}),
sub_zero_point
);
auto
scale
=
args
[
1
];
if
(
not
(
scale
->
get_shape
().
elements
()
==
1
))
{
axis
=
tune_axis
(
n_dim
,
axis
,
opd
.
op_name
);
scale
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"dims"
,
input_lens
}}),
scale
);
}
return
info
.
add_broadcastable_binary_op
(
"mul"
,
dequant_input
,
scale
);
}
}
};
};
...
...
src/onnx/parse_expand.cpp
View file @
4a39a0f7
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -23,8 +24,7 @@ struct parse_expand : op_parser<parse_expand>
...
@@ -23,8 +24,7 @@ struct parse_expand : op_parser<parse_expand>
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
arg_s
.
visit
([
&
](
auto
input
)
{
dims
.
assign
(
input
.
begin
(),
input
.
end
());
});
arg_s
.
visit
([
&
](
auto
input
)
{
dims
.
assign
(
input
.
begin
(),
input
.
end
());
});
auto
out_lens
=
compute_broadcasted_lens
(
in_lens
,
dims
);
auto
out_lens
=
compute_broadcasted_lens
(
in_lens
,
dims
);
return
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"output_lens"
,
out_lens
}}),
return
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
0
]);
args
[
0
]);
}
}
};
};
...
...
src/onnx/parse_gather_elements.cpp
View file @
4a39a0f7
...
@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
...
@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
info
.
add_literal
(
literal
(
ind_s
,
data_indices
.
begin
(),
data_indices
.
end
()));
info
.
add_literal
(
literal
(
ind_s
,
data_indices
.
begin
(),
data_indices
.
end
()));
auto
l_dim_idx
=
info
.
add_literal
(
literal
(
ind_s
,
vec_axis_ind
.
begin
(),
vec_axis_ind
.
end
()));
auto
l_dim_idx
=
info
.
add_literal
(
literal
(
ind_s
,
vec_axis_ind
.
begin
(),
vec_axis_ind
.
end
()));
auto
l_stride
=
info
.
add_literal
(
literal
{{
ind_s
.
type
(),
{
1
}},
{
axis_stride
}});
auto
l_stride
=
info
.
add_literal
(
literal
{{
ind_s
.
type
(),
{
1
}},
{
axis_stride
}});
l_stride
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"output_lens"
,
ind_s
.
lens
()}}),
l_stride
=
l_stride
);
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ind_s
.
lens
()}}),
l_stride
);
auto
dim_diff
=
info
.
add_instruction
(
make_op
(
"sub"
),
arg_ind
,
l_dim_idx
);
auto
dim_diff
=
info
.
add_instruction
(
make_op
(
"sub"
),
arg_ind
,
l_dim_idx
);
auto
delta
=
info
.
add_instruction
(
make_op
(
"mul"
),
dim_diff
,
l_stride
);
auto
delta
=
info
.
add_instruction
(
make_op
(
"mul"
),
dim_diff
,
l_stride
);
auto
ind
=
info
.
add_instruction
(
make_op
(
"add"
),
l_shape_idx
,
delta
);
auto
ind
=
info
.
add_instruction
(
make_op
(
"add"
),
l_shape_idx
,
delta
);
...
...
src/onnx/parse_gemm.cpp
View file @
4a39a0f7
...
@@ -42,13 +42,30 @@ struct parse_gemm : op_parser<parse_gemm>
...
@@ -42,13 +42,30 @@ struct parse_gemm : op_parser<parse_gemm>
// swap the last two elements
// swap the last two elements
std
::
swap
(
*
perm
.
rbegin
(),
*
(
perm
.
rbegin
()
+
1
));
std
::
swap
(
*
perm
.
rbegin
(),
*
(
perm
.
rbegin
()
+
1
));
auto
l1
=
(
transa
)
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"dims"
,
perm
}}),
args
[
0
])
auto
l1
=
args
[
0
];
:
args
[
0
];
auto
dot_type
=
l1
->
get_shape
().
type
();
auto
l2
=
(
transb
)
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"dims"
,
perm
}}),
args
[
1
])
:
args
[
1
];
if
(
alpha
!=
1.0
f
)
{
auto
alpha_literal
=
info
.
add_literal
(
alpha
);
l1
=
info
.
add_broadcastable_binary_op
(
"mul"
,
alpha_literal
,
l1
);
if
(
l1
->
get_shape
().
type
()
!=
dot_type
)
{
l1
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
l1
);
}
}
l1
=
(
transa
)
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
l1
)
:
l1
;
auto
l2
=
(
transb
)
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
:
args
[
1
];
auto
ret
=
info
.
add_instruction
(
make_op
(
"dot"
),
l1
,
l2
);
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
if
(
beta
!=
0.
f
&&
args
[
2
]
->
get_shape
().
elements
()
>
0
)
if
(
not
float_equal
(
beta
,
0.0
f
)
&&
args
[
2
]
->
get_shape
().
elements
()
>
0
)
{
{
auto
out_lens
=
l1
->
get_shape
().
lens
();
auto
out_lens
=
l1
->
get_shape
().
lens
();
out_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
();
out_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
();
...
@@ -56,15 +73,22 @@ struct parse_gemm : op_parser<parse_gemm>
...
@@ -56,15 +73,22 @@ struct parse_gemm : op_parser<parse_gemm>
auto
l3_lens
=
l3
->
get_shape
().
lens
();
auto
l3_lens
=
l3
->
get_shape
().
lens
();
if
(
!
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
l3_lens
.
begin
(),
l3_lens
.
end
()))
if
(
!
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
l3_lens
.
begin
(),
l3_lens
.
end
()))
{
{
l3
=
info
.
add_instruction
(
l3
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
make_op
(
"multibroadcast"
,
{{
"output_lens"
,
out_lens
}}),
args
[
2
]);
args
[
2
]);
}
}
return
info
.
add_instruction
(
auto
beta_literal
=
info
.
add_literal
(
beta
);
make_op
(
"dot"
,
{{
"alpha"
,
alpha
},
{
"beta"
,
beta
}}),
l1
,
l2
,
l3
);
auto
beta_l3
=
info
.
add_broadcastable_binary_op
(
"mul"
,
l3
,
beta_literal
);
if
(
beta_l3
->
get_shape
().
type
()
!=
dot_type
)
{
beta_l3
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
beta_l3
);
}
return
info
.
add_instruction
(
make_op
(
"add"
),
ret
,
beta_l3
);
}
}
}
}
return
info
.
add_instruction
(
make_op
(
"dot"
,
{{
"alpha"
,
alpha
},
{
"beta"
,
beta
}}),
l1
,
l2
)
;
return
ret
;
}
}
};
};
...
...
src/onnx/parse_generic_op.cpp
View file @
4a39a0f7
...
@@ -35,6 +35,8 @@ struct parse_generic_op : op_parser<parse_generic_op>
...
@@ -35,6 +35,8 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
"Reciprocal"
,
"recip"
},
{
"Reciprocal"
,
"recip"
},
{
"Relu"
,
"relu"
},
{
"Relu"
,
"relu"
},
{
"Round"
,
"round"
},
{
"Round"
,
"round"
},
{
"Scatter"
,
"scatter"
},
{
"ScatterElements"
,
"scatter"
},
{
"Sigmoid"
,
"sigmoid"
},
{
"Sigmoid"
,
"sigmoid"
},
{
"Sign"
,
"sign"
},
{
"Sign"
,
"sign"
},
{
"Sin"
,
"sin"
},
{
"Sin"
,
"sin"
},
...
@@ -47,7 +49,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
...
@@ -47,7 +49,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
bool
needs_contiguous
(
const
std
::
string
&
op_name
)
const
bool
needs_contiguous
(
const
std
::
string
&
op_name
)
const
{
{
return
contains
({
"
gath
er"
},
op_name
);
return
contains
({
"
flatten"
,
"gather"
,
"scatt
er"
},
op_name
);
}
}
instruction_ref
parse
(
const
op_desc
&
opd
,
instruction_ref
parse
(
const
op_desc
&
opd
,
...
...
src/onnx/parse_if.cpp
View file @
4a39a0f7
#include <migraphx/instruction_ref.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/checks.hpp>
...
@@ -26,59 +27,41 @@ struct parse_if : op_parser<parse_if>
...
@@ -26,59 +27,41 @@ struct parse_if : op_parser<parse_if>
MIGRAPHX_THROW
(
"PARSE_IF: condition input can have only one element!"
);
MIGRAPHX_THROW
(
"PARSE_IF: condition input can have only one element!"
);
}
}
migraphx
::
argument
cond_arg
=
args
.
front
()
->
eval
();
std
::
string
then_name
=
info
.
name
+
"_if"
;
// cond is not constant, need to create sub_modules
module_ref
then_mdl
=
parser
.
prog
.
create_module
(
then_name
);
if
(
cond_arg
.
empty
())
{
std
::
string
then_name
=
info
.
name
+
"_if"
;
module_ref
then_mdl
=
parser
.
prog
.
create_module
(
then_name
);
std
::
string
else_name
=
info
.
name
+
"_else"
;
module_ref
else_mdl
=
parser
.
prog
.
create_module
(
else_name
);
// parse the then sub_graph
std
::
string
else_name
=
info
.
name
+
"_else"
;
parser
.
parse_graph
(
then_mdl
,
then_graph
);
module_ref
else_mdl
=
parser
.
prog
.
create_module
(
else_name
);
// parse
_
the
else
sub_graph
// parse
the
then
sub_graph
parser
.
parse_graph
(
else
_mdl
,
else
_graph
);
parser
.
parse_graph
(
then
_mdl
,
then
_graph
);
auto
then_out_shapes
=
then_mdl
->
get_output_shapes
();
// parse_the else sub_graph
auto
else_out_shapes
=
else_mdl
->
get_output_shapes
();
parser
.
parse_graph
(
else_mdl
,
else_graph
);
if
(
not
std
::
equal
(
then_out_shapes
.
begin
(),
then_out_shapes
.
end
(),
else_out_shapes
.
begin
(),
else_out_shapes
.
end
()))
{
MIGRAPHX_THROW
(
"PARSE_IF: then and else sub_grahps must have same output shapes!"
);
}
auto
ret
=
info
.
add_instruction
(
make_op
(
"if"
),
args
,
{
then_mdl
,
else_mdl
});
auto
then_out_shapes
=
then_mdl
->
get_output_shapes
();
auto
else_out_shapes
=
else_mdl
->
get_output_shapes
();
return
{
ret
};
if
(
not
std
::
equal
(
then_out_shapes
.
begin
(),
}
then_out_shapes
.
end
(),
else
else_out_shapes
.
begin
(),
else_out_shapes
.
end
()))
{
{
auto
*
mod
=
info
.
mod
;
MIGRAPHX_THROW
(
"PARSE_IF: then and else sub_grahps must have same output shapes!"
);
// then branch
}
if
(
cond_arg
.
at
<
bool
>
())
{
parser
.
parse_graph
(
mod
,
then_graph
);
}
// else branch
else
{
parser
.
parse_graph
(
mod
,
else_graph
);
}
// inputs of the return instruction are that of the output of the
auto
if_ret
=
info
.
add_instruction
(
make_op
(
"if"
),
args
,
{
then_mdl
,
else_mdl
});
// if instruction
auto
out_s
=
if_ret
->
get_shape
();
instruction_ref
ret_ins
=
std
::
prev
(
mod
->
end
());
assert
(
out_s
.
type
()
==
shape
::
tuple_type
);
auto
outputs
=
ret_ins
->
inputs
();
assert
(
ret_ins
->
name
()
==
"@return"
);
mod
->
remove_instruction
(
ret_ins
);
return
outputs
;
const
auto
&
vec_shapes
=
out_s
.
sub_shapes
();
std
::
vector
<
instruction_ref
>
out_inss
;
for
(
std
::
size_t
i
=
0
;
i
<
vec_shapes
.
size
();
++
i
)
{
auto
ret
=
info
.
add_instruction
(
make_op
(
"get_tuple_elem"
,
{{
"index"
,
i
}}),
if_ret
);
out_inss
.
push_back
(
ret
);
}
}
return
out_inss
;
}
}
};
};
...
...
src/onnx/parse_imagescalar.cpp
View file @
4a39a0f7
...
@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
...
@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
auto
img_scaled
=
auto
img_scaled
=
info
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
args
.
front
(),
scale_tensor
);
info
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
args
.
front
(),
scale_tensor
);
auto
bias_bcast
=
info
.
add_instruction
(
auto
bias_bcast
=
info
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"
dim
s"
,
input_lens
}}),
bias_vals
);
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"
out_len
s"
,
input_lens
}}),
bias_vals
);
return
info
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
img_scaled
,
bias_bcast
);
return
info
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
img_scaled
,
bias_bcast
);
}
}
};
};
...
...
src/onnx/parse_instancenorm.cpp
View file @
4a39a0f7
...
@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
...
@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto
mean
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
x
);
auto
mean
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
x
);
auto
mean_bcast
=
auto
mean_bcast
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
dims
}}),
mean
);
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
mean
);
auto
l0
=
info
.
add_instruction
(
make_op
(
"sqdiff"
),
x
,
mean_bcast
);
auto
l0
=
info
.
add_instruction
(
make_op
(
"sqdiff"
),
x
,
mean_bcast
);
auto
variance
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
l0
);
auto
variance
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
l0
);
auto
l1
=
info
.
add_instruction
(
make_op
(
"sub"
),
x
,
mean_bcast
);
auto
l1
=
info
.
add_instruction
(
make_op
(
"sub"
),
x
,
mean_bcast
);
auto
epsilon_literal
=
info
.
add_literal
(
epsilon
);
auto
epsilon_literal
=
info
.
add_literal
(
epsilon
);
auto
epsilon_bcast
=
info
.
add_instruction
(
auto
epsilon_bcast
=
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
dims
}}),
epsilon_literal
);
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
epsilon_literal
);
auto
variance_bcast
=
auto
variance_bcast
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
dims
}}),
variance
);
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
variance
);
auto
l2
=
info
.
add_instruction
(
make_op
(
"add"
),
variance_bcast
,
epsilon_bcast
);
auto
l2
=
info
.
add_instruction
(
make_op
(
"add"
),
variance_bcast
,
epsilon_bcast
);
auto
l3
=
info
.
add_instruction
(
make_op
(
"rsqrt"
),
l2
);
auto
l3
=
info
.
add_instruction
(
make_op
(
"rsqrt"
),
l2
);
auto
l4
=
info
.
add_instruction
(
make_op
(
"mul"
),
l1
,
l3
);
auto
l4
=
info
.
add_instruction
(
make_op
(
"mul"
),
l1
,
l3
);
auto
scale_bcast
=
auto
scale_bcast
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"
dim
s"
,
dims
}}),
scale
);
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"
out_len
s"
,
dims
}}),
scale
);
;
;
auto
bias_bcast
=
auto
bias_bcast
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"
dim
s"
,
dims
}}),
bias
);
info
.
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"
out_len
s"
,
dims
}}),
bias
);
auto
l5
=
info
.
add_instruction
(
make_op
(
"mul"
),
l4
,
scale_bcast
);
auto
l5
=
info
.
add_instruction
(
make_op
(
"mul"
),
l4
,
scale_bcast
);
return
info
.
add_instruction
(
make_op
(
"add"
),
l5
,
bias_bcast
);
return
info
.
add_instruction
(
make_op
(
"add"
),
l5
,
bias_bcast
);
}
}
...
...
src/onnx/parse_loop.cpp
0 → 100644
View file @
4a39a0f7
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_loop
:
op_parser
<
parse_loop
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Loop"
}};
}
std
::
vector
<
instruction_ref
>
parse
(
const
op_desc
&
/*opd*/
,
onnx_parser
&
parser
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
// default value of the max_iter_num
int64_t
max_iterations
=
parser
.
max_loop_iterations
;
// iteration input is empty
if
(
args
.
at
(
0
)
->
name
()
==
"undefined"
)
{
shape
iter_s
{
shape
::
int64_type
};
args
[
0
]
=
info
.
add_literal
(
literal
(
iter_s
,
{
max_iterations
}));
}
else
{
auto
arg_iters
=
args
.
at
(
0
)
->
eval
();
if
(
not
arg_iters
.
empty
())
{
max_iterations
=
arg_iters
.
at
<
int64_t
>
();
}
}
// condition input is empty
if
(
args
.
at
(
1
)
->
name
()
==
"undefined"
)
{
shape
cond_s
{
shape
::
bool_type
};
args
[
1
]
=
info
.
add_literal
(
literal
(
cond_s
,
{
true
}));
}
// retrieve the subgraph
const
auto
&
sub_graph
=
info
.
attributes
.
at
(
"body"
).
g
();
std
::
string
mod_name
=
info
.
name
+
"_loop"
;
module_ref
sub_mod
=
parser
.
prog
.
create_module
(
mod_name
);
// parse the sub_graph
parser
.
parse_graph
(
sub_mod
,
sub_graph
);
auto
ret
=
info
.
add_instruction
(
make_op
(
"loop"
,
{{
"max_iterations"
,
max_iterations
}}),
args
,
{
sub_mod
});
auto
out_s
=
ret
->
get_shape
();
assert
(
out_s
.
type
()
==
shape
::
tuple_type
);
const
auto
&
vec_shapes
=
out_s
.
sub_shapes
();
std
::
vector
<
instruction_ref
>
out_inss
;
for
(
std
::
size_t
i
=
0
;
i
<
vec_shapes
.
size
();
++
i
)
{
auto
r
=
info
.
add_instruction
(
make_op
(
"get_tuple_elem"
,
{{
"index"
,
i
}}),
ret
);
out_inss
.
push_back
(
r
);
}
return
out_inss
;
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_matmul.cpp
View file @
4a39a0f7
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -57,18 +58,16 @@ struct parse_matmul : op_parser<parse_matmul>
...
@@ -57,18 +58,16 @@ struct parse_matmul : op_parser<parse_matmul>
if
(
l0_lens
!=
l0_broadcasted_lens
)
if
(
l0_lens
!=
l0_broadcasted_lens
)
{
{
bl0
=
info
.
add_instruction
(
bl0
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
l0_broadcasted_lens
}}),
l0
);
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l0_broadcasted_lens
}}),
l0
);
}
}
if
(
l1_lens
!=
l1_broadcasted_lens
)
if
(
l1_lens
!=
l1_broadcasted_lens
)
{
{
bl1
=
info
.
add_instruction
(
bl1
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
l1_broadcasted_lens
}}),
l1
);
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l1_broadcasted_lens
}}),
l1
);
}
}
}
}
instruction_ref
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
bl0
,
bl1
);
auto
dot_res
=
int64_t
num_axis
=
static_cast
<
int64_t
>
(
dot_res
->
get_shape
().
lens
().
size
());
info
.
add_instruction
(
make_op
(
opd
.
op_name
,
{{
"alpha"
,
1
},
{
"beta"
,
0
}}),
bl0
,
bl1
);
int64_t
num_axis
=
static_cast
<
int64_t
>
(
dot_res
->
get_shape
().
lens
().
size
());
if
(
is_a_prepended
)
if
(
is_a_prepended
)
{
{
dot_res
=
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
num_axis
-
2
}}}),
dot_res
);
dot_res
=
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
num_axis
-
2
}}}),
dot_res
);
...
...
src/onnx/parse_multinomial.cpp
0 → 100644
View file @
4a39a0f7
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <random>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_multinomial
:
op_parser
<
parse_multinomial
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Multinomial"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
int
dtype
=
6
;
if
(
contains
(
info
.
attributes
,
"dtype"
))
dtype
=
info
.
attributes
.
at
(
"dtype"
).
i
();
shape
::
type_t
output_type
=
get_type
(
dtype
);
size_t
sample_size
=
1
;
if
(
contains
(
info
.
attributes
,
"sample_size"
))
sample_size
=
info
.
attributes
.
at
(
"sample_size"
).
i
();
float
seed
=
static_cast
<
float
>
(
std
::
chrono
::
high_resolution_clock
::
now
().
time_since_epoch
().
count
());
if
(
contains
(
info
.
attributes
,
"seed"
))
seed
=
info
.
attributes
.
at
(
"seed"
).
f
();
// Subtract the per-batch maximum log-probability, making the per-batch max 0
auto
maxes
=
info
.
add_instruction
(
migraphx
::
make_op
(
"reduce_max"
,
{{
"axes"
,
{
1
}}}),
args
[
0
]);
auto
mb_maxes
=
info
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
args
[
0
]
->
get_shape
().
lens
()}}),
maxes
);
auto
cdf
=
info
.
add_instruction
(
migraphx
::
make_op
(
"sub"
),
args
[
0
],
mb_maxes
);
// Take the element-wise exponent to get probabilities in the range (0, 1]
cdf
=
info
.
add_instruction
(
migraphx
::
make_op
(
"exp"
),
cdf
);
// Compute the cumulative density function
cdf
=
info
.
add_instruction
(
migraphx
::
make_op
(
"prefix_scan_sum"
,
{{
"axis"
,
1
},
{
"exclusive"
,
false
}}),
cdf
);
// Pre-compute random distribution
std
::
mt19937
gen
(
seed
);
std
::
uniform_real_distribution
<>
dis
(
0.0
,
1.0
);
size_t
batch_size
=
args
[
0
]
->
get_shape
().
lens
().
front
();
migraphx
::
shape
dist_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
sample_size
}};
std
::
vector
<
float
>
random_dist
(
batch_size
*
sample_size
);
std
::
generate
(
random_dist
.
begin
(),
random_dist
.
end
(),
[
&
]()
{
return
dis
(
gen
);
});
auto
dist_lit
=
info
.
add_literal
(
migraphx
::
literal
{
dist_shape
,
random_dist
});
return
info
.
add_instruction
(
migraphx
::
make_op
(
"multinomial"
,
{{
"dtype"
,
output_type
}}),
cdf
,
dist_lit
);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_nonzero.cpp
View file @
4a39a0f7
...
@@ -9,7 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -9,7 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
onnx
{
namespace
onnx
{
template
<
class
T
>
template
<
class
T
>
std
::
vector
<
std
::
size_t
>
nonzero_indices
(
const
std
::
vector
<
T
>&
data
)
static
std
::
vector
<
std
::
size_t
>
nonzero_indices
(
const
std
::
vector
<
T
>&
data
)
{
{
std
::
vector
<
std
::
size_t
>
indices
;
std
::
vector
<
std
::
size_t
>
indices
;
for
(
std
::
size_t
i
=
0
;
i
<
data
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
data
.
size
();
++
i
)
...
@@ -31,30 +31,35 @@ struct parse_nonzero : op_parser<parse_nonzero>
...
@@ -31,30 +31,35 @@ struct parse_nonzero : op_parser<parse_nonzero>
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
args
)
const
{
{
migraphx
::
argument
data_arg
=
args
.
back
()
->
eval
();
migraphx
::
argument
data_arg
=
args
.
back
()
->
eval
();
check_arg_empty
(
data_arg
,
"PARSE_NONZERO: cannot support non-constant input!"
);
if
(
data_arg
.
empty
())
{
std
::
vector
<
std
::
size_t
>
indices
;
return
info
.
add_instruction
(
make_op
(
"nonzero"
),
args
);
data_arg
.
visit
([
&
](
auto
val
)
{
}
using
val_type
=
std
::
remove_cv_t
<
typename
decltype
(
val
)
::
value_type
>
;
else
std
::
vector
<
val_type
>
vec_data
;
{
vec_data
.
assign
(
val
.
begin
(),
val
.
end
());
std
::
vector
<
std
::
size_t
>
indices
;
indices
=
nonzero_indices
(
vec_data
);
data_arg
.
visit
([
&
](
auto
val
)
{
});
using
val_type
=
std
::
remove_cv_t
<
typename
decltype
(
val
)
::
value_type
>
;
std
::
vector
<
val_type
>
vec_data
;
vec_data
.
assign
(
val
.
begin
(),
val
.
end
());
indices
=
nonzero_indices
(
vec_data
);
});
shape
in_s
=
args
[
0
]
->
get_shape
();
shape
in_s
=
args
[
0
]
->
get_shape
();
shape
out_s
{
shape
::
int64_type
,
{
in_s
.
lens
().
size
(),
indices
.
size
()}};
shape
out_s
{
shape
::
int64_type
,
{
in_s
.
lens
().
size
(),
indices
.
size
()}};
std
::
vector
<
int64_t
>
out_data
(
out_s
.
elements
());
std
::
vector
<
int64_t
>
out_data
(
out_s
.
elements
());
for
(
std
::
size_t
i
=
0
;
i
<
indices
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
indices
.
size
();
++
i
)
{
auto
idx
=
in_s
.
multi
(
indices
[
i
]);
for
(
std
::
size_t
j
=
0
;
j
<
in_s
.
lens
().
size
();
++
j
)
{
{
out_data
[
out_s
.
index
({
j
,
i
})]
=
idx
[
j
];
auto
idx
=
in_s
.
multi
(
indices
[
i
]);
for
(
std
::
size_t
j
=
0
;
j
<
in_s
.
lens
().
size
();
++
j
)
{
out_data
[
out_s
.
index
({
j
,
i
})]
=
idx
[
j
];
}
}
}
}
return
info
.
add_literal
(
literal
(
out_s
,
out_data
));
return
info
.
add_literal
(
literal
(
out_s
,
out_data
));
}
}
}
};
};
...
...
src/onnx/parse_onehot.cpp
View file @
4a39a0f7
...
@@ -45,8 +45,9 @@ struct parse_onehot : op_parser<parse_onehot>
...
@@ -45,8 +45,9 @@ struct parse_onehot : op_parser<parse_onehot>
std
::
vector
<
int64_t
>
perm
(
n_rank
-
1
);
std
::
vector
<
int64_t
>
perm
(
n_rank
-
1
);
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
0
);
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
0
);
perm
.
insert
(
perm
.
begin
()
+
tuned_axis
,
n_rank
-
1
);
perm
.
insert
(
perm
.
begin
()
+
tuned_axis
,
n_rank
-
1
);
auto
tr_out
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"dims"
,
perm
}}),
gather_out
);
auto
tr_out
=
auto
lens
=
tr_out
->
get_shape
().
lens
();
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
gather_out
);
auto
lens
=
tr_out
->
get_shape
().
lens
();
auto
off_val
=
info
.
add_instruction
(
auto
off_val
=
info
.
add_instruction
(
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
2
]);
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
2
]);
...
@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot>
...
@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot>
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
2
]);
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
2
]);
auto
diff
=
info
.
add_instruction
(
make_op
(
"sub"
),
on_val
,
off_val
);
auto
diff
=
info
.
add_instruction
(
make_op
(
"sub"
),
on_val
,
off_val
);
auto
unsq_off_val
=
auto
unsq_off_val
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
lens
}}),
off_val
);
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
off_val
);
auto
unsq_diff_val
=
auto
unsq_diff_val
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out
put
_lens"
,
lens
}}),
diff
);
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
diff
);
auto
l_mul
=
info
.
add_instruction
(
make_op
(
"mul"
),
tr_out
,
unsq_diff_val
);
auto
l_mul
=
info
.
add_instruction
(
make_op
(
"mul"
),
tr_out
,
unsq_diff_val
);
return
info
.
add_instruction
(
make_op
(
"add"
),
l_mul
,
unsq_off_val
);
return
info
.
add_instruction
(
make_op
(
"add"
),
l_mul
,
unsq_off_val
);
}
}
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment