gaoqiong / MIGraphX / Commits / 870a396b
"megatron/git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "a7a12f823fe19d89ccaf58877b14290be85c66d9"
Commit 870a396b, authored Jan 23, 2023 by Khalique Ahmed

manual merge

Parents: 228b665c, d309e02f

Changes: 473 changed files in the commit; this page shows 20 changed files with 474 additions and 163 deletions (+474 −163).
Changed files on this page:

src/layout_nhwc.cpp (+23 −0)
src/load_save.cpp (+0 −1)
src/module.cpp (+94 −3)
src/onnx/conv.cpp (+1 −1)
src/onnx/onnx_parser.cpp (+31 −8)
src/onnx/parse_batchnorm.cpp (+50 −14)
src/onnx/parse_binary_op.cpp (+6 −0)
src/onnx/parse_convolution.cpp (+0 −2)
src/onnx/parse_deconvolution.cpp (+5 −1)
src/onnx/parse_gemm.cpp (+49 −36)
src/onnx/parse_matmul.cpp (+58 −34)
src/onnx/parse_pad.cpp (+6 −0)
src/onnx/parse_pooling.cpp (+82 −38)
src/onnx/parse_reduce_op.cpp (+1 −2)
src/onnx/parse_reshape.cpp (+1 −1)
src/onnx/parse_split.cpp (+18 −6)
src/onnx/parse_transpose.cpp (+1 −1)
src/opt/memory_coloring_impl.cpp (+7 −7)
src/pad_calc.cpp (+33 −8)
src/pass_manager.cpp (+8 −0)
src/layout_nhwc.cpp  (view file @ 870a396b)

+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
 #include <migraphx/layout_nhwc.hpp>
 #include <migraphx/module.hpp>
 #include <migraphx/instruction.hpp>
...
src/load_save.cpp  (view file @ 870a396b)

...
@@ -25,7 +25,6 @@
 #include <migraphx/file_buffer.hpp>
 #include <migraphx/json.hpp>
 #include <migraphx/msgpack.hpp>
-#include <migraphx/file_buffer.hpp>
 #include <fstream>

 namespace migraphx {
...
src/module.cpp  (view file @ 870a396b)

...
@@ -34,7 +34,6 @@
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_target.hpp>
-#include <migraphx/make_op.hpp>
 #include <migraphx/json.hpp>
 #include <iostream>
 #include <sstream>
...
@@ -385,9 +384,13 @@ instruction_ref module::move_instruction(instruction_ref src, instruction_ref ds
 instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst)
 {
-    this->move_instruction(src, dst);
     for(auto ins : src->inputs())
-        this->move_instruction(ins, src);
+    {
+        if(not contains(this->impl->instructions, ins))
+            continue;
+        this->move_instructions(ins, dst);
+    }
+    this->move_instruction(src, dst);
     return src;
 }
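The reordering here is the substantive fix: move_instructions now relocates an instruction's inputs ahead of it, recursively, and skips inputs that live in another module, so that after the move every definition still precedes its first use. A stand-alone toy sketch of that ordering, using an illustrative node type rather than the MIGraphX API:

    // Toy illustration (not MIGraphX source): relocate `src` before `dst`,
    // first moving any of src's inputs present in `prog`, so definitions
    // stay ahead of uses. Assumes the graph is acyclic.
    #include <algorithm>
    #include <list>

    struct node
    {
        int id;
        std::list<node*> inputs;
    };

    void move_before(std::list<node*>& prog, node* src, node* dst)
    {
        for(node* in : src->inputs)
        {
            if(std::find(prog.begin(), prog.end(), in) == prog.end())
                continue; // input belongs to another module; leave it alone
            move_before(prog, in, dst);
        }
        prog.remove(src);
        prog.insert(std::find(prog.begin(), prog.end(), dst), src);
    }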
...
@@ -786,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
     return to_c_id("x_" + replace_string(name, ":", "_module_"));
 }

+static void print_py_op(std::ostream& os, const operation& op)
+{
+    auto v = op.to_value();
+    os << "migraphx.op(" << enclose_name(op.name());
+    auto default_values = make_op(op.name()).to_value();
+    for(auto&& x : v)
+    {
+        auto name = x.get_key();
+        if(default_values[name] == x)
+            continue;
+        os << ", " << name << "=" << to_json_string(x.without_key());
+    }
+    os << ")";
+}
+
 static void print_make_op(std::ostream& os, const operation& op)
 {
     auto v = op.to_value();
...
@@ -801,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
     os << ")";
 }

+static void print_py_shape(std::ostream& os, const migraphx::shape& s)
+{
+    os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens());
+    if(not s.standard())
+        os << ", strides=" << to_json_string(s.strides());
+    os << ")";
+}
+
 static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
 {
     os << "migraphx::shape{migraphx::shape::" << s.type_string();
...
@@ -810,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
     os << "}";
 }

+std::unordered_map<instruction_ref, std::string>
+module::print_py(std::ostream& os,
+                 const std::string& mname,
+                 std::unordered_map<instruction_ref, std::string> names) const
+{
+    // cppcheck-suppress variableScope
+    unsigned long seed = names.size();
+    auto last          = std::prev(this->end());
+    names              = this->print(
+        [&](auto ins, auto ins_names) {
+            std::vector<std::string> input_vars;
+            std::transform(ins->inputs().begin(),
+                           ins->inputs().end(),
+                           std::back_inserter(input_vars),
+                           [&](auto input) { return cpp_var_name(ins_names.at(input)); });
+            if(ins != last)
+                os << cpp_var_name(ins_names.at(ins)) << " = ";
+            if(ins->name() == "@literal")
+            {
+                os << mname << ".add_literal(";
+                bool use_abs = false;
+                ins->get_literal().visit([&](auto v) {
+                    use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
+                });
+                // Disable abs for now
+                use_abs = false;
+                if(use_abs)
+                    os << "migraphx.abs_literal(";
+                os << "migraphx.generate_literal(";
+                print_py_shape(os, ins->get_shape());
+                os << ", " << seed << ")";
+                if(use_abs)
+                    os << ")";
+                os << ")" << std::endl;
+                seed++;
+            }
+            else if(ins->name() == "@param")
+            {
+                std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
+                os << mname << ".add_parameter(" << enclose_name(name) << ",";
+                print_py_shape(os, ins->get_shape());
+                os << ")" << std::endl;
+            }
+            else if(ins->name() == "@return")
+            {
+                os << mname << ".add_return([" << join_strings(input_vars, ", ") << "])"
+                   << std::endl;
+            }
+            else
+            {
+                assert(ins->name().front() != '@');
+                os << mname << ".add_instruction(";
+                print_py_op(os, ins->get_operator());
+                os << ", [" << join_strings(input_vars, ", ") << "]";
+                os << ")" << std::endl;
+            }
+        },
+        names);
+    return names;
+}
+
 std::unordered_map<instruction_ref, std::string>
 module::print_cpp(std::ostream& os,
                   const std::string& mname,
...
@@ -871,6 +960,8 @@ module::print_cpp(std::ostream& os,
     return names;
 }

+void module::print_py(std::ostream& os) const { this->print_py(os, this->name(), {}); }
+
 void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }

 void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
...
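As a usage note, print_py serializes a module as calls against the migraphx Python API. Judging from the format strings above, the emitted text for a small module would look roughly like this (hypothetical output; the x_* variable names are placeholders for whatever module::print assigns):

    x_input = m.add_parameter("input", migraphx.shape(float_type, lens=[1, 3]))
    x_w = m.add_literal(migraphx.generate_literal(migraphx.shape(float_type, lens=[3, 3]), 0))
    x_out = m.add_instruction(migraphx.op("dot"), [x_input, x_w])
    m.add_return([x_out])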
src/onnx/conv.cpp  (view file @ 870a396b)

...
@@ -30,7 +30,7 @@ namespace onnx {
 void recalc_conv_attributes(value& v, size_t kdims)
 {
-    if(not(v["padding"].size() == kdims or v["padding"].size() == kdims * 2))
+    if(v["padding"].size() != kdims and v["padding"].size() != kdims * 2)
     {
         v["padding"].resize(kdims);
         std::fill_n(v["padding"].begin(), kdims, 0);
...
src/onnx/onnx_parser.cpp  (view file @ 870a396b)

...
@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
 {
     if(args.size() == 3)
     {
-        auto bias_bcast = mod->add_instruction(
-            make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
-            args[2]);
+        instruction_ref bias_bcast;
+        // if curr_ins has a dynamic output shape use 2 input broadcast
+        if(curr_ins->get_shape().dynamic())
+        {
+            bias_bcast =
+                mod->add_instruction(make_op("broadcast", {{"axis", axis}}), args[2], curr_ins);
+        }
+        else
+        {
+            bias_bcast = mod->add_instruction(
+                make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
+                args[2]);
+        }
         return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
     }
     return curr_ins;
...
@@ -393,18 +403,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
 literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
 {
     std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
-    if(not t.external_data().empty())
+    auto type = get_type(t.data_type());
+    shape tensor_shape(type, dims);
+    auto external_data = t.external_data();
+    if(not external_data.empty())
     {
-        const std::string& data_file = t.external_data().at(0).value();
-        auto raw_buffer              = read_buffer(path + "/" + data_file);
+        const std::string& data_file = external_data.at(0).value();
+        size_t num_data_fields       = external_data.size();
+        size_t offset                = 0;
+        size_t nbytes                = tensor_shape.bytes();
+        if(num_data_fields > 1) // if offset field is present
+        {
+            offset = std::stoul(t.external_data().at(1).value());
+        }
+        if(num_data_fields > 2) // if nbytes field is present
+        {
+            nbytes = std::stoul(t.external_data().at(2).value());
+        }
+        auto raw_buffer = read_buffer(path + "/" + data_file, offset, nbytes);
         std::string s(raw_buffer.begin(), raw_buffer.end());
-        auto type = get_type(t.data_type());
         return create_literal(type, dims, s.data());
     }
     if(t.has_raw_data())
     {
         const std::string& s = t.raw_data();
-        auto type            = get_type(t.data_type());
         return create_literal(type, dims, s.data());
     }
...
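The three positional external_data fields consumed above follow the ONNX external-data convention, where a tensor's payload lives in a side file described by a location plus an optional byte offset and byte count. A minimal stand-alone sketch of that fallback logic (illustrative types, not the MIGraphX or ONNX API):

    #include <cstddef>
    #include <string>
    #include <vector>

    struct external_field { std::string value; };

    struct external_info
    {
        std::string file;   // field 0: data file relative to the model
        std::size_t offset; // field 1 (optional): byte offset into the file
        std::size_t nbytes; // field 2 (optional): number of bytes to read
    };

    external_info parse_external(const std::vector<external_field>& fields,
                                 std::size_t tensor_bytes)
    {
        // defaults: start of file, whole tensor
        external_info r{fields.at(0).value, 0, tensor_bytes};
        if(fields.size() > 1)
            r.offset = std::stoul(fields.at(1).value);
        if(fields.size() > 2)
            r.nbytes = std::stoul(fields.at(2).value);
        return r;
    }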
src/onnx/parse_batchnorm.cpp  (view file @ 870a396b)

...
@@ -24,7 +24,7 @@
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/make_op.hpp>
-#include <migraphx/op/batch_norm_inference.hpp>
+#include <migraphx/instruction.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -36,28 +36,64 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
     instruction_ref parse(const op_desc& /*opd*/,
                           const onnx_parser& parser,
-                          onnx_parser::node_info info,
-                          const std::vector<instruction_ref>& args) const
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
     {
         float epsilon  = 1e-5f;
-        float momentum = 0.9f;
-        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
         if(contains(info.attributes, "epsilon"))
         {
             epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
         }
-        if(contains(info.attributes, "momentum"))
+        auto x_lens = args[0]->get_shape().max_lens();
+        auto x_type = args[0]->get_shape().type();
+        if(std::any_of(args.cbegin() + 1, args.cend(), [](auto a) {
+               return a->get_shape().lens().size() != 1;
+           }))
+        {
+            MIGRAPHX_THROW("PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1");
+        }
+        auto x_rank = x_lens.size();
+        if(x_rank == 1 or x_rank == 2)
+        {
+            auto rt      = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
+            auto eps     = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto numer   = info.add_broadcastable_binary_op("sub", args[0], args[3]);
+            auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
+            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
+            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
+            auto r0      = info.add_broadcastable_binary_op("mul", div0, args[1]);
+            return info.add_broadcastable_binary_op("add", r0, args[2]);
+        }
+        else if(x_rank > 2)
         {
-            momentum = parser.parse_value(info.attributes.at("momentum")).at<float>();
+            // unsqueeze tensors of shape (C) to broadcast correctly
+            std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
+            std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
+            auto rt  = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
+            auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto scale_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
+            auto bias_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[2]);
+            auto mean_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
+            auto var_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
+            auto numer   = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
+            auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
+            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
+            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
+            auto r0      = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
+            return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
         }
-        if(contains(info.attributes, "spatial"))
+        else
         {
-            bn_mode = (parser.parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
-                          ? op::batch_norm_inference::spatial
-                          : op::batch_norm_inference::per_activation;
+            // rank == 0
+            MIGRAPHX_THROW("PARSE_BATCHNORM: rank " + std::to_string(x_lens.size()) +
+                           " input tensor, unhandled data format");
         }
-        op::batch_norm_inference op{epsilon, momentum, bn_mode};
-        return info.add_instruction(op, args);
     }
 };
...
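Both new branches implement the standard BatchNormalization inference formula, with pow(var + eps, 0.5) standing in for the square root:

    y = scale * (x - mean) / sqrt(var + eps) + bias

For inputs of rank > 2 the rank-1 scale/bias/mean/var tensors are first unsqueezed to shape (C, 1, ..., 1) so the elementwise ops broadcast across the spatial dimensions rather than the channel axis.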
src/onnx/parse_binary_op.cpp  (view file @ 870a396b)

...
@@ -57,6 +57,12 @@ struct parse_binary_op : op_parser<parse_binary_op>
             parser.parse_value(info.attributes.at("broadcast")).at<uint64_t>();
         if(broadcasted != 0)
         {
+            if(std::any_of(args.cbegin(), args.cend(), [](auto a) {
+                   return a->get_shape().dynamic();
+               }))
+            {
+                MIGRAPHX_THROW("Binary op broadcast attribute not supported for dynamic input shapes");
+            }
             uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
             auto l        = info.add_instruction(
                 make_op("broadcast",
...
src/onnx/parse_convolution.cpp  (view file @ 870a396b)

...
@@ -125,11 +125,9 @@ struct parse_convolution : op_parser<parse_convolution>
                 values["padding_mode"] = is_same_upper
                                              ? to_value(op::padding_mode_t::same_upper)
                                              : to_value(op::padding_mode_t::same_lower);
-                values["use_dynamic_same_auto_pad"] = true;
             }
             else
             {
-                values["padding_mode"] = to_value(op::padding_mode_t::same);
                 // kernel shape will be fixed, so max_lens() == min_len() for kernel lengths
                 auto weight_lens = weights->get_shape().max_lens();
                 std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
...
src/onnx/parse_deconvolution.cpp  (view file @ 870a396b)

...
@@ -95,6 +95,8 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             check_attr_sizes(
                 kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
         }

+        // TODO: auto padding needs to be implemented for this parser and operator
         if(contains(info.attributes, "auto_pad"))
         {
             auto s = info.attributes["auto_pad"].s();
...
@@ -106,7 +108,9 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             if(s.find("SAME") != std::string::npos)
             {
-                values["padding_mode"] = to_value(op::padding_mode_t::same);
+                bool is_same_upper = (s.find("SAME_UPPER") != std::string::npos);
+                values["padding_mode"] = is_same_upper
+                                             ? to_value(op::padding_mode_t::same_upper)
+                                             : to_value(op::padding_mode_t::same_lower);
             }
         }
...
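For reference on the modes distinguished above: per the ONNX auto_pad attribute, SAME padding sizes the total pad so the output spatial extent equals ceil(input / stride), and when that total is odd SAME_UPPER places the extra element at the end of the axis while SAME_LOWER places it at the beginning. For example, a total pad of 3 splits as 1 before / 2 after under SAME_UPPER and 2 before / 1 after under SAME_LOWER.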
src/onnx/parse_gemm.cpp  (view file @ 870a396b)

...
@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
-        float alpha = 1.0f;
-        float beta  = 1.0f;
-        bool transa = false;
-        bool transb = false;
+        auto a_arg = args[0];
+        auto b_arg = args[1];
+        if(a_arg->get_shape().ndim() != 2 or b_arg->get_shape().ndim() != 2)
+        {
+            MIGRAPHX_THROW("PARSE_GEMM: A and B should be rank 2, A is rank " +
+                           std::to_string(a_arg->get_shape().ndim()) + ", B is rank " +
+                           std::to_string(b_arg->get_shape().ndim()));
+        }
+        float alpha  = 1.0f;
+        float beta   = 1.0f;
+        bool trans_a = false;
+        bool trans_b = false;
         if(contains(info.attributes, "alpha"))
         {
             alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
...
@@ -53,61 +62,65 @@ struct parse_gemm : op_parser<parse_gemm>
         }
         if(contains(info.attributes, "transA"))
         {
-            transa = parser.parse_value(info.attributes.at("transA")).at<bool>();
+            trans_a = parser.parse_value(info.attributes.at("transA")).at<bool>();
         }
         if(contains(info.attributes, "transB"))
         {
-            transb = parser.parse_value(info.attributes.at("transB")).at<bool>();
+            trans_b = parser.parse_value(info.attributes.at("transB")).at<bool>();
         }

-        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
-        std::iota(perm.begin(), perm.end(), int64_t{0});
-        // swap the last two elements
-        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
-
-        auto l1       = args[0];
-        auto dot_type = l1->get_shape().type();
+        std::vector<int64_t> perm = {1, 0};
+        auto dot_type             = a_arg->get_shape().type();
         if(alpha != 1.0f)
         {
             auto alpha_literal = info.add_literal(alpha);
-            l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
-            if(l1->get_shape().type() != dot_type)
+            a_arg = info.add_broadcastable_binary_op("mul", alpha_literal, a_arg);
+            if(a_arg->get_shape().type() != dot_type)
             {
-                l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1);
+                a_arg = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), a_arg);
             }
         }

-        l1 = (transa)
-                 ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1)
-                 : l1;
-        auto l2 = (transb)
-                      ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
-                      : args[1];
+        a_arg = (trans_a)
+                    ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), a_arg)
+                    : a_arg;
+        b_arg = (trans_b)
+                    ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
+                    : args[1];

-        auto ret = info.add_instruction(make_op("dot"), l1, l2);
+        auto ret = info.add_instruction(make_op("dot"), a_arg, b_arg);
         if(args.size() == 3)
         {
+            // TODO: support dynamic C input
+            if(std::any_of(args.cbegin(), args.cend(), [](auto in_arg) {
+                   return in_arg->get_shape().dynamic();
+               }))
+            {
+                MIGRAPHX_THROW("PARSE_GEMM: C input not handled for dynamic input shapes");
+            }
-            if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
+            if(not float_equal(beta, 0.0f) and args[2]->get_shape().elements() > 0)
             {
-                auto out_lens   = l1->get_shape().lens();
-                out_lens.back() = l2->get_shape().lens().back();
-                auto l3         = args[2];
-                auto l3_lens    = l3->get_shape().lens();
-                if(not std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
+                auto out_lens   = a_arg->get_shape().lens();
+                out_lens.back() = b_arg->get_shape().lens().back();
+                auto c_arg      = args[2];
+                auto c_lens     = c_arg->get_shape().lens();
+                if(not std::equal(out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
                 {
-                    l3 = info.add_instruction(
+                    c_arg = info.add_instruction(
                         make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
                 }
                 auto beta_literal = info.add_literal(beta);
-                auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
-                if(beta_l3->get_shape().type() != dot_type)
+                auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
+                if(beta_c->get_shape().type() != dot_type)
                 {
-                    beta_l3 = info.add_instruction(
-                        make_op("convert", {{"target_type", dot_type}}), beta_l3);
+                    beta_c = info.add_instruction(
+                        make_op("convert", {{"target_type", dot_type}}), beta_c);
                 }
-                return info.add_instruction(make_op("add"), ret, beta_l3);
+                return info.add_instruction(make_op("add"), ret, beta_c);
             }
         }
...
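For context, ONNX Gemm computes

    Y = alpha * op(A) * op(B) + beta * C,  where op(X) = X^T when the matching transA/transB attribute is set.

Because the parser now rejects anything that is not rank 2, the fixed permutation {1, 0} is always the correct transpose, which is what allowed the old iota/swap construction of perm to be deleted.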
src/onnx/parse_matmul.cpp  (view file @ 870a396b)

...
@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
                           const onnx_parser::node_info& info,
                           std::vector<instruction_ref> args) const
     {
-        auto l0      = args[0];
-        auto l1      = args[1];
-        auto l0_lens = l0->get_shape().lens();
-        auto l1_lens = l1->get_shape().lens();
+        auto a0 = args[0];
+        auto a1 = args[1];
+        auto s0 = a0->get_shape();
+        auto s1 = a1->get_shape();

-        // args[0] is a vector, prepend 1 to the shape
+        instruction_ref dot_res;
         bool is_a_prepended = false;
-        if(l0_lens.size() == 1)
+        bool is_b_appended  = false;
+        if(s0.ndim() == 1)
         {
             is_a_prepended = true;
-            l0_lens.insert(l0_lens.begin(), 1);
-            l0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
+            a0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
         }
-        bool is_b_appended = false;
-        if(l1_lens.size() == 1)
+        if(s1.ndim() == 1)
         {
             is_b_appended = true;
-            l1_lens.push_back(1);
-            l1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
+            a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
         }

-        instruction_ref bl0 = l0;
-        instruction_ref bl1 = l1;
-        if(not std::equal(
-               l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
+        if(s0.dynamic() or s1.dynamic())
         {
-            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
-            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
-            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
-            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
-            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
-            l0_broadcasted_lens = output_lens;
-            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
-            l1_broadcasted_lens = output_lens;
-            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
-            if(l0_lens != l0_broadcasted_lens)
+            if(opd.op_name == "quant_dot")
             {
-                bl0 = info.add_instruction(
-                    make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
+                MIGRAPHX_THROW("PARSE_MATMUL: dynamic MatMulInteger not supported");
             }
-            if(l1_lens != l1_broadcasted_lens)
+            auto s0_dds = a0->get_shape().to_dynamic().dyn_dims();
+            auto s1_dds = a1->get_shape().to_dynamic().dyn_dims();
+            // TODO: handling this case requires a new multibroadcast mode
+            if(not std::equal(
+                   s0_dds.rbegin() + 2, s0_dds.rend(), s1_dds.rbegin() + 2, s1_dds.rend()))
             {
-                bl1 = info.add_instruction(
-                    make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
+                MIGRAPHX_THROW("PARSE_MATMUL: dynamic shape broadcasting not supported");
             }
+            dot_res = info.add_instruction(make_op(opd.op_name), a0, a1);
         }
+        else
+        {
+            auto s0_lens        = a0->get_shape().lens();
+            auto s1_lens        = a1->get_shape().lens();
+            instruction_ref ba0 = a0;
+            instruction_ref ba1 = a1;
+            // try broadcasting if dimensions other than last two do not match
+            if(not std::equal(
+                   s0_lens.rbegin() + 2, s0_lens.rend(), s1_lens.rbegin() + 2, s1_lens.rend()))
+            {
+                auto l0_it = s0_lens.begin() + s0_lens.size() - 2;
+                std::vector<std::size_t> l0_broadcasted_lens(s0_lens.begin(), l0_it);
+                auto l1_it = s1_lens.begin() + s1_lens.size() - 2;
+                std::vector<std::size_t> l1_broadcasted_lens(s1_lens.begin(), l1_it);
+                auto output_lens =
+                    compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
+                l0_broadcasted_lens = output_lens;
+                l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, s0_lens.end());
+                l1_broadcasted_lens = output_lens;
+                l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, s1_lens.end());
+                if(s0_lens != l0_broadcasted_lens)
+                {
+                    ba0 = info.add_instruction(
+                        make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), a0);
+                }
+                if(s1_lens != l1_broadcasted_lens)
+                {
+                    ba1 = info.add_instruction(
+                        make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1);
+                }
+            }
+            dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
+        }

-        instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
-        int64_t num_axis        = static_cast<int64_t>(dot_res->get_shape().lens().size());
+        // squeeze the appended or prepended dimensions
+        int64_t num_axis = dot_res->get_shape().ndim();
         if(is_a_prepended)
         {
             dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
...
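The static branch keeps numpy-style matmul broadcasting: the last two dimensions multiply as matrices and all leading (batch) dimensions are broadcast together. A worked shape example:

    A: (2, 1, 3, 4)   batch dims (2, 1), matrix (3, 4)
    B:    (5, 4, 6)   batch dims (5),    matrix (4, 6)
    broadcast batch dims: (2, 5)
    A multibroadcast to (2, 5, 3, 4), B to (2, 5, 4, 6)
    dot result: (2, 5, 3, 6)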
src/onnx/parse_pad.cpp  (view file @ 870a396b)

...
@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad>
         {
             auto mode = info.attributes.at("mode").s();
             if(mode == "reflect")
+            {
+                if(args.front()->get_shape().dynamic())
+                {
+                    MIGRAPHX_THROW("PARSE_PAD: reflect padding with dynamic shape not supported");
+                }
                 return reflect_pad(info, pads, args.front());
+            }
             if(mode != "constant")
             {
                 MIGRAPHX_THROW(
...
src/onnx/parse_pooling.cpp  (view file @ 870a396b)

...
@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling>
                 {"GlobalLpPool", "lpnorm"}};
     }

-    instruction_ref parse(const op_desc& opd,
-                          const onnx_parser& /*parser*/,
-                          onnx_parser::node_info info,
-                          std::vector<instruction_ref> args) const
+    value handle_values(const op_desc& opd,
+                        onnx_parser::node_info info,
+                        const shape& in_shape,
+                        value values) const
     {
-        const std::unordered_map<std::string, op::pooling_mode> mode_map = {
-            {"max", op::pooling_mode::max},
-            {"average", op::pooling_mode::average},
-            {"lpnorm", op::pooling_mode::lpnorm}};
-        std::string mode = opd.op_name;
-        if(not contains(mode_map, mode))
-        {
-            MIGRAPHX_THROW("onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
-        }
-        operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
-        value values = op.to_value();
-        auto l0      = args[0];
-        auto in_lens = l0->get_shape().lens();
-        assert(in_lens.size() > 2);
-        auto kdims = in_lens.size() - 2;
+        auto kdims = in_shape.ndim() - 2;

         if(starts_with(opd.onnx_name, "Global"))
         {
-            values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
+            // if spatial dimensions are dynamic use dyn_global flag
+            if(in_shape.dynamic() and
+               std::any_of(in_shape.dyn_dims().cbegin() + 2,
+                           in_shape.dyn_dims().cend(),
+                           [](auto dd) { return not dd.is_fixed(); }))
+            {
+                values["dyn_global"] = true;
+                values["lengths"]    = std::vector<size_t>();
+            }
+            else
+            {
+                // works with static and fixed dynamic shape
+                auto m_lens       = in_shape.max_lens();
+                values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
+            }
         }

-        // does not support ceil_mode
         if(contains(info.attributes, "ceil_mode"))
         {
             values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
         }

-        // count include padding, if count include pad is 1, we always use
-        // explicit pad
-        int count_include_pad = 0;
-        if(contains(info.attributes, "count_include_pad"))
-        {
-            count_include_pad = info.attributes.at("count_include_pad").i();
-        }
-
         if(contains(info.attributes, "strides"))
         {
             values["stride"].clear();
             copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
             check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
         }
         if(contains(info.attributes, "kernel_shape"))
         {
             values["lengths"].clear();
...
@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling>
         // ensure pads availabe only when auto_pad is "NOT_SET"
         check_padding_mode(info, "POOLING");

+        return values;
+    }
+
+    instruction_ref parse(const op_desc& opd,
+                          const onnx_parser& /*parser*/,
+                          onnx_parser::node_info info,
+                          std::vector<instruction_ref> args) const
+    {
+        std::string mode = opd.op_name;
+        const std::unordered_map<std::string, op::pooling_mode> mode_map = {
+            {"max", op::pooling_mode::max},
+            {"average", op::pooling_mode::average},
+            {"lpnorm", op::pooling_mode::lpnorm}};
+        if(not contains(mode_map, mode))
+        {
+            MIGRAPHX_THROW(
+                "PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
+        }
+        operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
+        value values = op.to_value();
+        auto l0       = args[0];
+        auto in_shape = l0->get_shape();
+        assert(in_shape.ndim() > 2);
+        auto kdims = in_shape.ndim() - 2;
+
+        values = handle_values(opd, info, in_shape, values);
+
+        // count include padding, if count include pad is 1, we always use
+        // explicit pad
+        int count_include_pad = 0;
+        if(contains(info.attributes, "count_include_pad"))
+        {
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
+                               "dynamic input shape");
+            }
+            count_include_pad = info.attributes.at("count_include_pad").i();
+        }
+
         std::vector<int64_t> paddings;
         float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
...
@@ -123,14 +153,22 @@ struct parse_pooling : op_parser<parse_pooling>
         if(contains(info.attributes, "auto_pad"))
         {
-            values["padding"].clear();
-            // return paddings could be empty, then setting to 0 for no padding
-            cal_auto_padding_size(info,
-                                  values,
-                                  values["lengths"].to_vector<std::size_t>(),
-                                  {1, 1},
-                                  in_lens,
-                                  paddings);
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW(
+                    "PARSE_POOLING: Auto padding pooling with dynamic input shape not supported");
+            }
+            else
+            {
+                values["padding"].clear();
+                // return paddings could be empty, then setting to 0 for no padding
+                cal_auto_padding_size(info,
+                                      values,
+                                      values["lengths"].to_vector<std::size_t>(),
+                                      {1, 1},
+                                      in_shape.lens(),
+                                      paddings);
+            }
         }
         if(paddings.size() != 2 * kdims)
...
@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling>
             values["stride"].resize(kdims);
             std::fill_n(values["stride"].begin(), kdims, 1);
         }
         // used to calculate the supposed output shape
         std::vector<int64_t> orig_padding = paddings;
...
@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling>
         if(not slice_start.empty())
         {
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW(
+                    "PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
+            }
             // calculate expected output shape
             orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
             orig_padding.insert(orig_padding.begin(), 2, 0);
...
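One detail worth calling out in pad_val above: the pad value has to be the identity element of the pooling reduction, so max pooling pads with the lowest representable float while average and lpnorm pooling pad with zero. A minimal illustration (not MIGraphX source):

    #include <limits>

    // max(x, lowest()) == x and x + 0 == x, so padding never changes the result
    float pad_value(bool is_max_pool)
    {
        return is_max_pool ? std::numeric_limits<float>::lowest() : 0.0f;
    }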
src/onnx/parse_reduce_op.cpp  (view file @ 870a396b)

...
@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
     }
     else
     {
-        std::size_t n_dim = args.front()->get_shape().lens().size();
-        axes.resize(n_dim);
+        axes.resize(args.front()->get_shape().ndim());
         std::iota(axes.begin(), axes.end(), 0);
     }
 }
...
src/onnx/parse_reshape.cpp  (view file @ 870a396b)

...
@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
         if(args.size() == 2)
         {
             auto s = args[1]->eval();
-            check_arg_empty(s, "Reshape: dynamic shape is not supported");
+            check_arg_empty(s, "Reshape: non-constant shape input is not supported");
             s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
         }
...
src/onnx/parse_split.cpp  (view file @ 870a396b)

...
@@ -26,6 +26,9 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/tune_axis.hpp>
+#include <migraphx/onnx/checks.hpp>
+#include <migraphx/stringutils.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -55,12 +58,12 @@ struct parse_split : op_parser<parse_split>
         {
             literal s = parser.parse_value(info.attributes.at("split"));
             s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
         }
-        if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
-           static_cast<int64_t>(lens[tuned_axis]))
+        else if(args.size() == 2)
         {
-            MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
+            auto s = args[1]->eval();
+            check_arg_empty(s, "Split: dynamic shape is not supported");
+            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
         }
         // no split attribute, input is equally divided
         else
...
@@ -74,6 +77,15 @@ struct parse_split : op_parser<parse_split>
             vec_splits.resize(info.num_outputs, dl);
         }

+        if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
+           static_cast<int64_t>(lens[tuned_axis]))
+        {
+            MIGRAPHX_THROW(
+                "PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:" +
+                std::to_string(lens[tuned_axis]) + " Output " + to_string_range(vec_splits) +
+                " Rank " + std::to_string(n_rank) + " Len outs " + to_string_range(lens));
+        }
+
         std::vector<instruction_ref> ret_ins;
         int64_t start = 0;
         for(auto sl : vec_splits)
...
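The relocated accumulate check enforces that the split sizes cover the axis exactly, whichever way they were supplied. For example, splitting an axis of length 10 with split = [3, 3, 4] yields slices [0, 3), [3, 6), [6, 10), while split = [3, 3, 3] now fails with the detailed diagnostic above rather than the old terse message.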
src/onnx/parse_transpose.cpp  (view file @ 870a396b)

...
@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose>
         }

         // if perm is empty, use the default value
-        auto n_dim = args.front()->get_shape().lens().size();
+        auto n_dim = args.front()->get_shape().ndim();
         if(perm.empty())
         {
             perm.resize(n_dim);
...
src/opt/memory_coloring_impl.cpp  (view file @ 870a396b)

...
@@ -72,7 +72,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
     if(conflict_table.find(vn) != conflict_table.end())
     {
-        std::set<int>& vn_set = conflict_table[vn];
+        const std::set<int>& vn_set = conflict_table[vn];
         for(const auto& iter : vn_set)
         {
             live_range* range = live_ranges[iter];
...
@@ -267,8 +267,8 @@ void memory_coloring_impl::verify()
 {
     for(int i = 0; i < num_of_lives; ++i)
     {
-        live_interval& interval = live_intervals[i];
-        live_range& segment     = interval.segment;
+        const live_interval& interval = live_intervals[i];
+        const live_range& segment     = interval.segment;
         if(segment.begin == invalid_offset)
         {
...
@@ -284,7 +284,7 @@ void memory_coloring_impl::verify()
         int vn = segment.vn;
         if(conflict_table.find(vn) != conflict_table.end())
         {
-            std::set<int>& vn_set = conflict_table[vn];
+            const std::set<int>& vn_set = conflict_table[vn];
             for(const auto& iter : vn_set)
             {
                 live_range* range = live_ranges[iter];
...
@@ -319,8 +319,8 @@ void memory_coloring_impl::dump_intervals()
     {
         std::cout << " segment:" << i;
         std::cout << " =>";
-        std::set<int>& table = conflict_table[i];
-        for(auto& iter : table)
+        const std::set<int>& table = conflict_table[i];
+        for(const auto& iter : table)
         {
             std::cout << (iter) << ",";
         }
...
@@ -357,7 +357,7 @@ void live_interval::dump()
     std::cout << "id:" << id;
     segment.dump();
     std::cout << " uses:";
-    for(auto& iter : use_points)
+    for(const auto& iter : use_points)
     {
         std::cout << " " << get_ins_enum(iter) << ",";
     }
...
src/pad_calc.cpp  (view file @ 870a396b)

...
@@ -52,19 +52,21 @@ void calculate_padding(int64_t idx,
     }
 }

-std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
-                                           std::vector<std::size_t> k_lens,
-                                           std::vector<std::size_t> strides,
-                                           std::vector<std::size_t> dilations,
+std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens,
+                                           const std::vector<std::size_t>& wei_lens,
+                                           const std::vector<std::size_t>& strides,
+                                           const std::vector<std::size_t>& dilations,
                                            bool use_upper)
 {
     std::vector<std::size_t> padding;
-    padding.resize(2 * k_lens.size());
-    for(std::size_t i = 0; i < padding.size() / 2; i++)
+    assert(input_lens.size() >= 3);
+    std::size_t num_spatial_dims = input_lens.size() - 2;
+    padding.resize(2 * num_spatial_dims);
+    for(std::size_t i = 0; i < num_spatial_dims; i++)
     {
-        std::ptrdiff_t input_dim  = tensor_lens[i];
+        std::ptrdiff_t input_dim  = input_lens[i + 2];
         std::ptrdiff_t stride     = strides[i];
-        std::ptrdiff_t weight_dim = k_lens[i];
+        std::ptrdiff_t weight_dim = wei_lens[i + 2];
         std::ptrdiff_t dilation   = dilations[i];
         std::ptrdiff_t output_dim = (input_dim + stride - 1) / stride; // round up result
         std::ptrdiff_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
...
@@ -86,5 +88,28 @@ std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
     return padding;
 }

+shape compute_padded_shape(const shape& input,
+                           const shape& weights,
+                           const std::vector<std::size_t>& padding,
+                           const std::vector<std::size_t>& stride,
+                           const std::vector<std::size_t>& dilation)
+{
+    const size_t num_spatial_dims = input.lens().size() - 2;
+    std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
+    // calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
+    for(size_t i = 0; i < num_spatial_dims; ++i)
+    {
+        auto padding_factor = padding[i] + padding[i + num_spatial_dims];
+        output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
+            1,
+            (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
+             padding_factor) /
+                    stride[i] +
+                1)));
+    }
+    return input.with_lens(output_lens);
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
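Plugging numbers into the commented formula ((W - K + 2P) / S) + 1, using the dilated kernel extent K_eff = 1 + dilation * (k - 1) that the code applies: for W = 7, k = 3, dilation = 1, padding 1 on each side (padding_factor = 2), and stride 2, the spatial output length is (7 - 3 + 2) / 2 + 1 = 4.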
src/pass_manager.cpp  (view file @ 870a396b)

...
@@ -94,11 +94,19 @@ struct module_pm : module_pass_manager
     virtual void run_pass(const pass& p) override
     {
         assert(mod);
+        timer ts{};
+        using seconds = std::chrono::duration<double>;
         trace("Module: ", mod->name(), ", Pass: ", p.name());
+        const double t1 = ts.record<seconds>();
         assert(mod->validate() == mod->end());
         p.apply(*this);
         trace(*mod);
         validate_pass(*mod, p, *t);
+        const double t2 = ts.record<seconds>();
+        trace("Pass: ", p.name(), " completed in (s): ", (t2 - t1));
     }
 };
...
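The added instrumentation samples a timer before and after each pass and traces the difference. A minimal sketch of such a timer, assuming a steady_clock-based implementation (the real `timer` type is defined elsewhere in MIGraphX):

    #include <chrono>

    struct timer
    {
        std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();

        // Returns the elapsed time converted to Duration, e.g.
        // record<std::chrono::duration<double>>() gives seconds as a double.
        template <class Duration>
        double record() const
        {
            return Duration(std::chrono::steady_clock::now() - start).count();
        }
    };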