Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
dae94657
"src/vscode:/vscode.git/clone" did not exist on "7d0ec543b7398ce5cf54fb3eb8d441ee8136a1e5"
Unverified
Commit
dae94657
authored
Dec 14, 2022
by
Chris Austen
Committed by
GitHub
Dec 14, 2022
Browse files
Merge branch 'develop' into jit-reduce-reg
parents
b013d991
56c43445
Changes
201
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
515 additions
and
106 deletions
+515
-106
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+16
-0
src/insert_pad.cpp
src/insert_pad.cpp
+2
-2
src/instruction.cpp
src/instruction.cpp
+18
-0
src/layout_nhwc.cpp
src/layout_nhwc.cpp
+118
-0
src/load_save.cpp
src/load_save.cpp
+0
-1
src/module.cpp
src/module.cpp
+88
-1
src/onnx/conv.cpp
src/onnx/conv.cpp
+1
-1
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+18
-5
src/onnx/parse_batchnorm.cpp
src/onnx/parse_batchnorm.cpp
+12
-11
src/onnx/parse_binary_op.cpp
src/onnx/parse_binary_op.cpp
+6
-0
src/onnx/parse_convolution.cpp
src/onnx/parse_convolution.cpp
+0
-2
src/onnx/parse_deconvolution.cpp
src/onnx/parse_deconvolution.cpp
+5
-1
src/onnx/parse_pooling.cpp
src/onnx/parse_pooling.cpp
+82
-38
src/onnx/parse_split.cpp
src/onnx/parse_split.cpp
+18
-6
src/onnx/parse_transpose.cpp
src/onnx/parse_transpose.cpp
+1
-1
src/pad_calc.cpp
src/pad_calc.cpp
+33
-8
src/pass_manager.cpp
src/pass_manager.cpp
+8
-0
src/program.cpp
src/program.cpp
+19
-0
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+19
-22
src/shape.cpp
src/shape.cpp
+51
-7
No files found.
src/include/migraphx/streamutils.hpp
View file @
dae94657
...
@@ -26,7 +26,9 @@
...
@@ -26,7 +26,9 @@
#include <ostream>
#include <ostream>
#include <algorithm>
#include <algorithm>
#include <migraphx/reflect.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <vector>
#include <vector>
...
@@ -83,6 +85,20 @@ auto stream_write_value_impl(rank<0>, std::ostream& os, const Range& r)
...
@@ -83,6 +85,20 @@ auto stream_write_value_impl(rank<0>, std::ostream& os, const Range& r)
os
<<
"}"
;
os
<<
"}"
;
}
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_reflectable
<
T
>{})
>
void
stream_write_value_impl
(
rank
<
0
>
,
std
::
ostream
&
os
,
const
T
&
x
)
{
char
delim
=
'{'
;
reflect_each
(
x
,
[
&
](
auto
&&
y
,
auto
name
)
{
os
<<
delim
;
os
<<
name
<<
"="
;
stream_write_value_impl
(
rank
<
2
>
{},
os
,
y
);
delim
=
','
;
});
if
(
delim
==
','
)
os
<<
"}"
;
}
}
// namespace detail
}
// namespace detail
template
<
class
T
>
template
<
class
T
>
...
...
src/insert_pad.cpp
View file @
dae94657
...
@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
...
@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
{
{
return
;
return
;
}
}
auto
kdims
=
input
->
get_shape
().
lens
().
size
()
-
2
;
auto
kdims
=
input
->
get_shape
().
ndim
()
-
2
;
if
(
std
::
equal
(
op
.
padding
.
begin
(),
if
(
std
::
equal
(
op
.
padding
.
begin
(),
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
end
()))
op
.
padding
.
end
()))
return
;
return
;
std
::
vector
<
int64_t
>
padding
(
input
->
get_shape
().
lens
().
size
()
*
2
,
0
);
std
::
vector
<
int64_t
>
padding
(
input
->
get_shape
().
ndim
()
*
2
,
0
);
std
::
vector
<
size_t
>
pads_l
(
op
.
padding
.
begin
(),
op
.
padding
.
begin
()
+
kdims
);
std
::
vector
<
size_t
>
pads_l
(
op
.
padding
.
begin
(),
op
.
padding
.
begin
()
+
kdims
);
std
::
vector
<
size_t
>
pads_r
(
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
end
());
std
::
vector
<
size_t
>
pads_r
(
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
end
());
op
.
padding
=
std
::
vector
<
size_t
>
(
kdims
*
2
,
0
);
op
.
padding
=
std
::
vector
<
size_t
>
(
kdims
*
2
,
0
);
...
...
src/instruction.cpp
100644 → 100755
View file @
dae94657
...
@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
...
@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
std
::
replace
(
module_args
.
begin
(),
module_args
.
end
(),
old
,
new_mod
);
std
::
replace
(
module_args
.
begin
(),
module_args
.
end
(),
old
,
new_mod
);
}
}
bool
instruction
::
is_undefined
()
const
{
if
(
op
.
name
()
==
"undefined"
)
{
return
true
;
}
else
if
(
this
->
inputs
().
empty
())
{
return
false
;
}
else
{
return
std
::
all_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
is_undefined
();
});
}
}
bool
instruction
::
can_eval
()
const
bool
instruction
::
can_eval
()
const
{
{
if
(
op
.
name
()
==
"@literal"
)
if
(
op
.
name
()
==
"@literal"
)
...
...
src/
rewrite_batchnorm
.cpp
→
src/
layout_nhwc
.cpp
View file @
dae94657
...
@@ -21,63 +21,98 @@
...
@@ -21,63 +21,98 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/
rewrite_batchnorm
.hpp>
#include <migraphx/
layout_nhwc
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/
module
.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
rewrite_batchnorm
::
apply
(
module
&
m
)
const
template
<
class
Predicate
>
std
::
vector
<
instruction_ref
>
find_lasts
(
const
module
&
m
,
Predicate
pred
)
{
std
::
vector
<
instruction_ref
>
result
;
fix
([
&
](
auto
self
,
auto
ins
)
{
if
(
pred
(
ins
))
{
result
.
push_back
(
ins
);
return
;
}
for
(
auto
input
:
ins
->
inputs
())
self
(
input
);
})(
std
::
prev
(
m
.
end
()));
return
result
;
}
std
::
unordered_set
<
instruction_ref
>
preserve_output_layout
(
module
&
m
)
{
std
::
unordered_set
<
instruction_ref
>
result
;
std
::
vector
<
instruction_ref
>
outputs
=
find_lasts
(
m
,
[](
auto
ins
)
{
return
ins
->
name
()
==
"convolution"
and
ins
->
get_shape
().
lens
().
size
()
==
4
;
});
for
(
auto
output
:
outputs
)
{
auto
permutation
=
find_permutation
(
output
->
get_shape
());
auto
layout
=
m
.
insert_instruction
(
std
::
next
(
output
),
make_op
(
"layout"
,
{{
"permutation"
,
permutation
}}),
output
);
result
.
insert
(
m
.
replace_instruction
(
output
,
layout
));
}
return
result
;
}
void
transform_convolutions
(
module
&
m
)
{
{
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
"
batch_norm_inference
"
)
if
(
ins
->
name
()
!=
"
convolution
"
)
continue
;
continue
;
// Get scale, bias, mean, variance from inputs
if
(
ins
->
get_shape
().
lens
().
size
()
!=
4
)
auto
gamma
=
ins
->
inputs
()[
1
]
->
eval
();
auto
bias
=
ins
->
inputs
()[
2
]
->
eval
();
auto
mean
=
ins
->
inputs
()[
3
]
->
eval
();
auto
variance
=
ins
->
inputs
()[
4
]
->
eval
();
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
continue
;
continue
;
auto
v
=
ins
->
get_operator
().
to_value
();
if
(
v
.
at
(
"group"
).
to
<
int
>
()
>
1
)
continue
;
auto
args
=
ins
->
inputs
();
std
::
transform
(
args
.
begin
(),
args
.
end
(),
args
.
begin
(),
[
&
](
const
auto
&
i
)
{
return
m
.
insert_instruction
(
ins
,
make_op
(
"layout"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
i
);
});
auto
conv
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
args
);
auto
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
conv
);
m
.
replace_instruction
(
ins
,
c
);
}
}
std
::
vector
<
std
::
size_t
>
lens
=
ins
->
inputs
()[
1
]
->
get_shape
().
lens
();
void
remove_layout
(
module
&
m
,
const
std
::
unordered_set
<
instruction_ref
>&
output_layouts
)
shape
s
{
ins
->
get_shape
().
type
(),
lens
};
{
// Get epsilon
for
(
auto
ins
:
iterator_for
(
m
))
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
{
auto
epsilon
=
bn_op
.
epsilon
;
if
(
ins
->
name
()
!=
"layout"
)
continue
;
argument
a
{
s
};
if
(
ins
->
get_shape
()
!=
ins
->
inputs
().
front
()
->
get_shape
())
argument
b
{
s
};
continue
;
visit_all
(
gamma
,
bias
,
mean
,
variance
,
a
,
b
)(
if
(
contains
(
output_layouts
,
ins
))
[
&
](
auto
gamma2
,
auto
bias2
,
auto
mean2
,
auto
variance2
,
auto
a2
,
auto
b2
)
{
continue
;
dfor
(
a
.
get_shape
().
elements
())(
m
.
replace_instruction
(
ins
,
ins
->
inputs
().
front
());
[
&
](
std
::
size_t
c
)
{
a2
[
c
]
=
gamma2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
);
});
dfor
(
b
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
b2
[
c
]
=
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
});
});
auto
broadcast
=
op
::
broadcast
{
1
,
ins
->
get_shape
().
lens
()};
auto
a_ins
=
m
.
add_literal
({
a
.
get_shape
(),
a
.
data
()});
auto
a_broadcast
=
m
.
insert_instruction
(
ins
,
broadcast
,
a_ins
);
auto
mul
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ins
->
inputs
().
front
(),
a_broadcast
);
auto
b_ins
=
m
.
add_literal
({
b
.
get_shape
(),
b
.
data
()});
auto
b_broadcast
=
m
.
insert_instruction
(
ins
,
broadcast
,
b_ins
);
auto
add
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
mul
,
b_broadcast
);
m
.
replace_instruction
(
ins
,
add
);
}
}
}
}
void
layout_nhwc
::
apply
(
module_pass_manager
&
mpm
)
const
{
std
::
unordered_set
<
instruction_ref
>
output_layouts
=
preserve_output_layout
(
mpm
.
get_module
());
transform_convolutions
(
mpm
.
get_module
());
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
eliminate_contiguous
{
"contiguous"
});
mpm
.
run_pass
(
dead_code_elimination
{});
remove_layout
(
mpm
.
get_module
(),
output_layouts
);
mpm
.
run_pass
(
dead_code_elimination
{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/load_save.cpp
View file @
dae94657
...
@@ -25,7 +25,6 @@
...
@@ -25,7 +25,6 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/msgpack.hpp>
#include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp>
#include <fstream>
#include <fstream>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/module.cpp
View file @
dae94657
...
@@ -34,7 +34,6 @@
...
@@ -34,7 +34,6 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
...
@@ -790,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
...
@@ -790,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
}
}
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
auto
v
=
op
.
to_value
();
os
<<
"migraphx.op("
<<
enclose_name
(
op
.
name
());
auto
default_values
=
make_op
(
op
.
name
()).
to_value
();
for
(
auto
&&
x
:
v
)
{
auto
name
=
x
.
get_key
();
if
(
default_values
[
name
]
==
x
)
continue
;
os
<<
", "
<<
name
<<
"="
<<
to_json_string
(
x
.
without_key
());
}
os
<<
")"
;
}
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
{
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
...
@@ -805,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
...
@@ -805,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
os
<<
")"
;
os
<<
")"
;
}
}
static
void
print_py_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
os
<<
"migraphx.shape("
<<
s
.
type_string
()
<<
", lens="
<<
to_json_string
(
s
.
lens
());
if
(
not
s
.
standard
())
os
<<
", strides="
<<
to_json_string
(
s
.
strides
());
os
<<
")"
;
}
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
{
os
<<
"migraphx::shape{migraphx::shape::"
<<
s
.
type_string
();
os
<<
"migraphx::shape{migraphx::shape::"
<<
s
.
type_string
();
...
@@ -814,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
...
@@ -814,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os
<<
"}"
;
os
<<
"}"
;
}
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_py
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
{
// cppcheck-suppress variableScope
unsigned
long
seed
=
names
.
size
();
auto
last
=
std
::
prev
(
this
->
end
());
names
=
this
->
print
(
[
&
](
auto
ins
,
auto
ins_names
)
{
std
::
vector
<
std
::
string
>
input_vars
;
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
input_vars
),
[
&
](
auto
input
)
{
return
cpp_var_name
(
ins_names
.
at
(
input
));
});
if
(
ins
!=
last
)
os
<<
cpp_var_name
(
ins_names
.
at
(
ins
))
<<
" = "
;
if
(
ins
->
name
()
==
"@literal"
)
{
os
<<
mname
<<
".add_literal("
;
bool
use_abs
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
use_abs
=
std
::
none_of
(
v
.
begin
(),
v
.
end
(),
[](
auto
x
)
{
return
x
<
0
;
});
});
// Disable abs for now
use_abs
=
false
;
if
(
use_abs
)
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.generate_literal("
;
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
", "
<<
seed
<<
")"
;
if
(
use_abs
)
os
<<
")"
;
os
<<
")"
<<
std
::
endl
;
seed
++
;
}
else
if
(
ins
->
name
()
==
"@param"
)
{
std
::
string
name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
os
<<
mname
<<
".add_parameter("
<<
enclose_name
(
name
)
<<
","
;
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
")"
<<
std
::
endl
;
}
else
if
(
ins
->
name
()
==
"@return"
)
{
os
<<
mname
<<
".add_return(["
<<
join_strings
(
input_vars
,
", "
)
<<
"])"
<<
std
::
endl
;
}
else
{
assert
(
ins
->
name
().
front
()
!=
'@'
);
os
<<
mname
<<
".add_instruction("
;
print_py_op
(
os
,
ins
->
get_operator
());
os
<<
", ["
<<
join_strings
(
input_vars
,
", "
)
<<
"]"
;
os
<<
")"
<<
std
::
endl
;
}
},
names
);
return
names
;
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_cpp
(
std
::
ostream
&
os
,
module
::
print_cpp
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
const
std
::
string
&
mname
,
...
@@ -875,6 +960,8 @@ module::print_cpp(std::ostream& os,
...
@@ -875,6 +960,8 @@ module::print_cpp(std::ostream& os,
return
names
;
return
names
;
}
}
void
module
::
print_py
(
std
::
ostream
&
os
)
const
{
this
->
print_py
(
os
,
this
->
name
(),
{});
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
this
->
name
(),
{});
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
this
->
name
(),
{});
}
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
...
...
src/onnx/conv.cpp
View file @
dae94657
...
@@ -30,7 +30,7 @@ namespace onnx {
...
@@ -30,7 +30,7 @@ namespace onnx {
void
recalc_conv_attributes
(
value
&
v
,
size_t
kdims
)
void
recalc_conv_attributes
(
value
&
v
,
size_t
kdims
)
{
{
if
(
not
(
v
[
"padding"
].
size
()
=
=
kdims
or
v
[
"padding"
].
size
()
=
=
kdims
*
2
)
)
if
(
v
[
"padding"
].
size
()
!
=
kdims
and
v
[
"padding"
].
size
()
!
=
kdims
*
2
)
{
{
v
[
"padding"
].
resize
(
kdims
);
v
[
"padding"
].
resize
(
kdims
);
std
::
fill_n
(
v
[
"padding"
].
begin
(),
kdims
,
0
);
std
::
fill_n
(
v
[
"padding"
].
begin
(),
kdims
,
0
);
...
...
src/onnx/onnx_parser.cpp
View file @
dae94657
...
@@ -393,18 +393,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
...
@@ -393,18 +393,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
literal
onnx_parser
::
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
literal
onnx_parser
::
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
{
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
if
(
not
t
.
external_data
().
empty
())
auto
type
=
get_type
(
t
.
data_type
());
shape
tensor_shape
(
type
,
dims
);
auto
external_data
=
t
.
external_data
();
if
(
not
external_data
.
empty
())
{
{
const
std
::
string
&
data_file
=
t
.
external_data
().
at
(
0
).
value
();
const
std
::
string
&
data_file
=
external_data
.
at
(
0
).
value
();
auto
raw_buffer
=
read_buffer
(
path
+
"/"
+
data_file
);
size_t
num_data_fields
=
external_data
.
size
();
size_t
offset
=
0
;
size_t
nbytes
=
tensor_shape
.
bytes
();
if
(
num_data_fields
>
1
)
// if offset field is present
{
offset
=
std
::
stoul
(
t
.
external_data
().
at
(
1
).
value
());
}
if
(
num_data_fields
>
2
)
// if nbytes field is present
{
nbytes
=
std
::
stoul
(
t
.
external_data
().
at
(
2
).
value
());
}
auto
raw_buffer
=
read_buffer
(
path
+
"/"
+
data_file
,
offset
,
nbytes
);
std
::
string
s
(
raw_buffer
.
begin
(),
raw_buffer
.
end
());
std
::
string
s
(
raw_buffer
.
begin
(),
raw_buffer
.
end
());
auto
type
=
get_type
(
t
.
data_type
());
return
create_literal
(
type
,
dims
,
s
.
data
());
return
create_literal
(
type
,
dims
,
s
.
data
());
}
}
if
(
t
.
has_raw_data
())
if
(
t
.
has_raw_data
())
{
{
const
std
::
string
&
s
=
t
.
raw_data
();
const
std
::
string
&
s
=
t
.
raw_data
();
auto
type
=
get_type
(
t
.
data_type
());
return
create_literal
(
type
,
dims
,
s
.
data
());
return
create_literal
(
type
,
dims
,
s
.
data
());
}
}
...
...
src/onnx/parse_batchnorm.cpp
View file @
dae94657
...
@@ -44,7 +44,7 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
...
@@ -44,7 +44,7 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
{
{
epsilon
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"epsilon"
)).
at
<
float
>
();
epsilon
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"epsilon"
)).
at
<
float
>
();
}
}
auto
x_lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
x_lens
=
args
[
0
]
->
get_shape
().
max_
lens
();
auto
x_type
=
args
[
0
]
->
get_shape
().
type
();
auto
x_type
=
args
[
0
]
->
get_shape
().
type
();
if
(
std
::
any_of
(
args
.
cbegin
()
+
1
,
args
.
cend
(),
[](
auto
a
)
{
if
(
std
::
any_of
(
args
.
cbegin
()
+
1
,
args
.
cend
(),
[](
auto
a
)
{
...
@@ -54,18 +54,19 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
...
@@ -54,18 +54,19 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
MIGRAPHX_THROW
(
"PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1"
);
MIGRAPHX_THROW
(
"PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1"
);
}
}
if
(
x_lens
.
size
()
==
1
)
auto
x_rank
=
x_lens
.
size
();
if
(
x_rank
==
1
or
x_rank
==
2
)
{
{
auto
rt
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
x_type
},
{
0.5
}});
auto
rt
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
x_type
},
{
0.5
}});
auto
eps
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
x_type
},
{
epsilon
}});
auto
eps
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
x_type
},
{
epsilon
}});
auto
n
0
=
info
.
add_broadcastable_binary_op
(
"sub"
,
args
[
0
],
args
[
3
]);
auto
n
umer
=
info
.
add_broadcastable_binary_op
(
"sub"
,
args
[
0
],
args
[
3
]);
auto
d0
=
info
.
add_broadcastable_binary_op
(
"add"
,
args
[
4
],
eps
);
auto
var_eps
=
info
.
add_broadcastable_binary_op
(
"add"
,
args
[
4
],
eps
);
auto
d
1
=
info
.
add_broadcastable_binary_op
(
"pow"
,
d0
,
rt
);
auto
d
enom
=
info
.
add_broadcastable_binary_op
(
"pow"
,
var_eps
,
rt
);
auto
div0
=
info
.
add_broadcastable_binary_op
(
"div"
,
n
0
,
d1
);
auto
div0
=
info
.
add_broadcastable_binary_op
(
"div"
,
n
umer
,
denom
);
auto
r0
=
info
.
add_broadcastable_binary_op
(
"mul"
,
div0
,
args
[
1
]);
auto
r0
=
info
.
add_broadcastable_binary_op
(
"mul"
,
div0
,
args
[
1
]);
return
info
.
add_broadcastable_binary_op
(
"add"
,
r0
,
args
[
2
]);
return
info
.
add_broadcastable_binary_op
(
"add"
,
r0
,
args
[
2
]);
}
}
else
if
(
x_
lens
.
size
()
>
2
)
else
if
(
x_
rank
>
2
)
{
{
// unsqueeze tensors of shape (C) to broadcast correctly
// unsqueeze tensors of shape (C) to broadcast correctly
std
::
vector
<
int64_t
>
unsqueeze_axes
(
x_lens
.
size
()
-
2
);
std
::
vector
<
int64_t
>
unsqueeze_axes
(
x_lens
.
size
()
-
2
);
...
@@ -89,7 +90,7 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
...
@@ -89,7 +90,7 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
}
}
else
else
{
{
//
num dims either 0 or 2
//
rank == 0
MIGRAPHX_THROW
(
"PARSE_BATCHNORM: rank "
+
std
::
to_string
(
x_lens
.
size
())
+
MIGRAPHX_THROW
(
"PARSE_BATCHNORM: rank "
+
std
::
to_string
(
x_lens
.
size
())
+
" input tensor, unhandled data format"
);
" input tensor, unhandled data format"
);
}
}
...
...
src/onnx/parse_binary_op.cpp
View file @
dae94657
...
@@ -57,6 +57,12 @@ struct parse_binary_op : op_parser<parse_binary_op>
...
@@ -57,6 +57,12 @@ struct parse_binary_op : op_parser<parse_binary_op>
parser
.
parse_value
(
info
.
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
parser
.
parse_value
(
info
.
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
if
(
broadcasted
!=
0
)
{
{
if
(
std
::
any_of
(
args
.
cbegin
(),
args
.
cend
(),
[](
auto
a
)
{
return
a
->
get_shape
().
dynamic
();
}))
{
MIGRAPHX_THROW
(
"Binary op broadcast attribute not supported for dynamic input shapes"
);
}
uint64_t
axis
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
();
uint64_t
axis
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
();
auto
l
=
info
.
add_instruction
(
auto
l
=
info
.
add_instruction
(
make_op
(
"broadcast"
,
make_op
(
"broadcast"
,
...
...
src/onnx/parse_convolution.cpp
View file @
dae94657
...
@@ -125,11 +125,9 @@ struct parse_convolution : op_parser<parse_convolution>
...
@@ -125,11 +125,9 @@ struct parse_convolution : op_parser<parse_convolution>
values
[
"padding_mode"
]
=
is_same_upper
values
[
"padding_mode"
]
=
is_same_upper
?
to_value
(
op
::
padding_mode_t
::
same_upper
)
?
to_value
(
op
::
padding_mode_t
::
same_upper
)
:
to_value
(
op
::
padding_mode_t
::
same_lower
);
:
to_value
(
op
::
padding_mode_t
::
same_lower
);
values
[
"use_dynamic_same_auto_pad"
]
=
true
;
}
}
else
else
{
{
values
[
"padding_mode"
]
=
to_value
(
op
::
padding_mode_t
::
same
);
// kernel shape will be fixed, so max_lens() == min_len() for kernel lengths
// kernel shape will be fixed, so max_lens() == min_len() for kernel lengths
auto
weight_lens
=
weights
->
get_shape
().
max_lens
();
auto
weight_lens
=
weights
->
get_shape
().
max_lens
();
std
::
vector
<
std
::
size_t
>
k_lens
(
weight_lens
.
begin
()
+
2
,
weight_lens
.
end
());
std
::
vector
<
std
::
size_t
>
k_lens
(
weight_lens
.
begin
()
+
2
,
weight_lens
.
end
());
...
...
src/onnx/parse_deconvolution.cpp
View file @
dae94657
...
@@ -95,6 +95,8 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
...
@@ -95,6 +95,8 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
check_attr_sizes
(
check_attr_sizes
(
kdims
,
values
[
"dilation"
].
size
(),
"PARSE_CONV_TRANSPOSE: inconsistent dilations"
);
kdims
,
values
[
"dilation"
].
size
(),
"PARSE_CONV_TRANSPOSE: inconsistent dilations"
);
}
}
// TODO: auto padding needs to be implemented for this parser and operator
if
(
contains
(
info
.
attributes
,
"auto_pad"
))
if
(
contains
(
info
.
attributes
,
"auto_pad"
))
{
{
auto
s
=
info
.
attributes
[
"auto_pad"
].
s
();
auto
s
=
info
.
attributes
[
"auto_pad"
].
s
();
...
@@ -106,7 +108,9 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
...
@@ -106,7 +108,9 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
if
(
s
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
if
(
s
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
{
values
[
"padding_mode"
]
=
to_value
(
op
::
padding_mode_t
::
same
);
bool
is_same_upper
=
(
s
.
find
(
"SAME_UPPER"
)
!=
std
::
string
::
npos
);
values
[
"padding_mode"
]
=
is_same_upper
?
to_value
(
op
::
padding_mode_t
::
same_upper
)
:
to_value
(
op
::
padding_mode_t
::
same_lower
);
}
}
}
}
...
...
src/onnx/parse_pooling.cpp
View file @
dae94657
...
@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling>
{
"GlobalLpPool"
,
"lpnorm"
}};
{
"GlobalLpPool"
,
"lpnorm"
}};
}
}
instruction_ref
parse
(
const
op_desc
&
opd
,
value
handle_values
(
const
op_desc
&
opd
,
const
onnx_parser
&
/*parser*/
,
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
const
shape
&
in_shape
,
std
::
vector
<
instruction_ref
>
arg
s
)
const
value
value
s
)
const
{
{
const
std
::
unordered_map
<
std
::
string
,
op
::
pooling_mode
>
mode_map
=
{
auto
kdims
=
in_shape
.
ndim
()
-
2
;
{
"max"
,
op
::
pooling_mode
::
max
},
{
"average"
,
op
::
pooling_mode
::
average
},
{
"lpnorm"
,
op
::
pooling_mode
::
lpnorm
}};
std
::
string
mode
=
opd
.
op_name
;
if
(
not
contains
(
mode_map
,
mode
))
{
MIGRAPHX_THROW
(
"onnx pooling mode must be [
\"
max
\"
,
\"
average
\"
,
\"
lpnorm
\"
]"
);
}
operation
op
=
make_op
(
"pooling"
,
{{
"mode"
,
mode_map
.
at
(
mode
)}});
value
values
=
op
.
to_value
();
auto
l0
=
args
[
0
];
auto
in_lens
=
l0
->
get_shape
().
lens
();
assert
(
in_lens
.
size
()
>
2
);
auto
kdims
=
in_lens
.
size
()
-
2
;
if
(
starts_with
(
opd
.
onnx_name
,
"Global"
))
if
(
starts_with
(
opd
.
onnx_name
,
"Global"
))
{
{
values
[
"lengths"
]
=
std
::
vector
<
size_t
>
(
in_lens
.
begin
()
+
2
,
in_lens
.
end
());
// if spatial dimensions are dynamic use dyn_global flag
if
(
in_shape
.
dynamic
()
and
std
::
any_of
(
in_shape
.
dyn_dims
().
cbegin
()
+
2
,
in_shape
.
dyn_dims
().
cend
(),
[](
auto
dd
)
{
return
not
dd
.
is_fixed
();
}))
{
values
[
"dyn_global"
]
=
true
;
values
[
"lengths"
]
=
std
::
vector
<
size_t
>
();
}
else
{
// works with static and fixed dynamic shape
auto
m_lens
=
in_shape
.
max_lens
();
values
[
"lengths"
]
=
std
::
vector
<
size_t
>
(
m_lens
.
begin
()
+
2
,
m_lens
.
end
());
}
}
}
// does not support ceil_mode
if
(
contains
(
info
.
attributes
,
"ceil_mode"
))
if
(
contains
(
info
.
attributes
,
"ceil_mode"
))
{
{
values
[
"ceil_mode"
]
=
static_cast
<
bool
>
(
info
.
attributes
.
at
(
"ceil_mode"
).
i
());
values
[
"ceil_mode"
]
=
static_cast
<
bool
>
(
info
.
attributes
.
at
(
"ceil_mode"
).
i
());
}
}
// count include padding, if count include pad is 1, we always use
// explicit pad
int
count_include_pad
=
0
;
if
(
contains
(
info
.
attributes
,
"count_include_pad"
))
{
count_include_pad
=
info
.
attributes
.
at
(
"count_include_pad"
).
i
();
}
if
(
contains
(
info
.
attributes
,
"strides"
))
if
(
contains
(
info
.
attributes
,
"strides"
))
{
{
values
[
"stride"
].
clear
();
values
[
"stride"
].
clear
();
copy
(
info
.
attributes
[
"strides"
].
ints
(),
std
::
back_inserter
(
values
[
"stride"
]));
copy
(
info
.
attributes
[
"strides"
].
ints
(),
std
::
back_inserter
(
values
[
"stride"
]));
check_attr_sizes
(
kdims
,
values
[
"stride"
].
size
(),
"PARSE_POOLING: inconsistent strides"
);
check_attr_sizes
(
kdims
,
values
[
"stride"
].
size
(),
"PARSE_POOLING: inconsistent strides"
);
}
}
if
(
contains
(
info
.
attributes
,
"kernel_shape"
))
if
(
contains
(
info
.
attributes
,
"kernel_shape"
))
{
{
values
[
"lengths"
].
clear
();
values
[
"lengths"
].
clear
();
...
@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling>
// ensure pads availabe only when auto_pad is "NOT_SET"
// ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode
(
info
,
"POOLING"
);
check_padding_mode
(
info
,
"POOLING"
);
return
values
;
}
instruction_ref
parse
(
const
op_desc
&
opd
,
const
onnx_parser
&
/*parser*/
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
std
::
string
mode
=
opd
.
op_name
;
const
std
::
unordered_map
<
std
::
string
,
op
::
pooling_mode
>
mode_map
=
{
{
"max"
,
op
::
pooling_mode
::
max
},
{
"average"
,
op
::
pooling_mode
::
average
},
{
"lpnorm"
,
op
::
pooling_mode
::
lpnorm
}};
if
(
not
contains
(
mode_map
,
mode
))
{
MIGRAPHX_THROW
(
"PARSE_POOLING: onnx pooling mode must be [
\"
max
\"
,
\"
average
\"
,
\"
lpnorm
\"
]"
);
}
operation
op
=
make_op
(
"pooling"
,
{{
"mode"
,
mode_map
.
at
(
mode
)}});
value
values
=
op
.
to_value
();
auto
l0
=
args
[
0
];
auto
in_shape
=
l0
->
get_shape
();
assert
(
in_shape
.
ndim
()
>
2
);
auto
kdims
=
in_shape
.
ndim
()
-
2
;
values
=
handle_values
(
opd
,
info
,
in_shape
,
values
);
// count include padding, if count include pad is 1, we always use
// explicit pad
int
count_include_pad
=
0
;
if
(
contains
(
info
.
attributes
,
"count_include_pad"
))
{
if
(
in_shape
.
dynamic
())
{
MIGRAPHX_THROW
(
"PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape"
);
}
count_include_pad
=
info
.
attributes
.
at
(
"count_include_pad"
).
i
();
}
std
::
vector
<
int64_t
>
paddings
;
std
::
vector
<
int64_t
>
paddings
;
float
pad_val
=
((
mode
==
"max"
)
?
std
::
numeric_limits
<
float
>::
lowest
()
:
0.0
f
);
float
pad_val
=
((
mode
==
"max"
)
?
std
::
numeric_limits
<
float
>::
lowest
()
:
0.0
f
);
...
@@ -123,14 +153,22 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -123,14 +153,22 @@ struct parse_pooling : op_parser<parse_pooling>
if
(
contains
(
info
.
attributes
,
"auto_pad"
))
if
(
contains
(
info
.
attributes
,
"auto_pad"
))
{
{
values
[
"padding"
].
clear
();
if
(
in_shape
.
dynamic
())
// return paddings could be empty, then setting to 0 for no padding
{
cal_auto_padding_size
(
info
,
MIGRAPHX_THROW
(
values
,
"PARSE_POOLING: Auto padding pooling with dynamic input shape not supported"
);
values
[
"lengths"
].
to_vector
<
std
::
size_t
>
(),
}
{
1
,
1
},
else
in_lens
,
{
paddings
);
values
[
"padding"
].
clear
();
// return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size
(
info
,
values
,
values
[
"lengths"
].
to_vector
<
std
::
size_t
>
(),
{
1
,
1
},
in_shape
.
lens
(),
paddings
);
}
}
}
if
(
paddings
.
size
()
!=
2
*
kdims
)
if
(
paddings
.
size
()
!=
2
*
kdims
)
...
@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling>
values
[
"stride"
].
resize
(
kdims
);
values
[
"stride"
].
resize
(
kdims
);
std
::
fill_n
(
values
[
"stride"
].
begin
(),
kdims
,
1
);
std
::
fill_n
(
values
[
"stride"
].
begin
(),
kdims
,
1
);
}
}
// used to calculate the supposed output shape
// used to calculate the supposed output shape
std
::
vector
<
int64_t
>
orig_padding
=
paddings
;
std
::
vector
<
int64_t
>
orig_padding
=
paddings
;
...
@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling>
if
(
not
slice_start
.
empty
())
if
(
not
slice_start
.
empty
())
{
{
if
(
in_shape
.
dynamic
())
{
MIGRAPHX_THROW
(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape"
);
}
// calculate expected output shape
// calculate expected output shape
orig_padding
.
insert
(
orig_padding
.
begin
()
+
kdims
,
2
,
0
);
orig_padding
.
insert
(
orig_padding
.
begin
()
+
kdims
,
2
,
0
);
orig_padding
.
insert
(
orig_padding
.
begin
(),
2
,
0
);
orig_padding
.
insert
(
orig_padding
.
begin
(),
2
,
0
);
...
...
src/onnx/parse_split.cpp
View file @
dae94657
...
@@ -26,6 +26,9 @@
...
@@ -26,6 +26,9 @@
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -55,12 +58,12 @@ struct parse_split : op_parser<parse_split>
...
@@ -55,12 +58,12 @@ struct parse_split : op_parser<parse_split>
{
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"split"
));
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"split"
));
s
.
visit
([
&
](
auto
v
)
{
vec_splits
.
assign
(
v
.
begin
(),
v
.
end
());
});
s
.
visit
([
&
](
auto
v
)
{
vec_splits
.
assign
(
v
.
begin
(),
v
.
end
());
});
}
if
(
std
::
accumulate
(
vec_splits
.
begin
(),
vec_splits
.
end
(),
int64_t
(
0
))
!=
else
if
(
args
.
size
()
==
2
)
static_cast
<
int64_t
>
(
lens
[
tuned_axis
]))
{
{
auto
s
=
args
[
1
]
->
eval
();
MIGRAPHX_THROW
(
"PARSE_SPLIT: sum of split attribute unequal to dim size of axis!
"
);
check_arg_empty
(
s
,
"Split: dynamic shape is not supported
"
);
}
s
.
visit
([
&
](
auto
v
)
{
vec_splits
.
assign
(
v
.
begin
(),
v
.
end
());
});
}
}
// no split attribute, input is equally divided
// no split attribute, input is equally divided
else
else
...
@@ -74,6 +77,15 @@ struct parse_split : op_parser<parse_split>
...
@@ -74,6 +77,15 @@ struct parse_split : op_parser<parse_split>
vec_splits
.
resize
(
info
.
num_outputs
,
dl
);
vec_splits
.
resize
(
info
.
num_outputs
,
dl
);
}
}
if
(
std
::
accumulate
(
vec_splits
.
begin
(),
vec_splits
.
end
(),
int64_t
(
0
))
!=
static_cast
<
int64_t
>
(
lens
[
tuned_axis
]))
{
MIGRAPHX_THROW
(
"PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:"
+
std
::
to_string
(
lens
[
tuned_axis
])
+
" Output "
+
to_string_range
(
vec_splits
)
+
" Rank "
+
std
::
to_string
(
n_rank
)
+
" Len outs "
+
to_string_range
(
lens
));
}
std
::
vector
<
instruction_ref
>
ret_ins
;
std
::
vector
<
instruction_ref
>
ret_ins
;
int64_t
start
=
0
;
int64_t
start
=
0
;
for
(
auto
sl
:
vec_splits
)
for
(
auto
sl
:
vec_splits
)
...
...
src/onnx/parse_transpose.cpp
View file @
dae94657
...
@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose>
...
@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose>
}
}
// if perm is empty, use the default value
// if perm is empty, use the default value
auto
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
auto
n_dim
=
args
.
front
()
->
get_shape
().
ndim
();
if
(
perm
.
empty
())
if
(
perm
.
empty
())
{
{
perm
.
resize
(
n_dim
);
perm
.
resize
(
n_dim
);
...
...
src/pad_calc.cpp
View file @
dae94657
...
@@ -52,19 +52,21 @@ void calculate_padding(int64_t idx,
...
@@ -52,19 +52,21 @@ void calculate_padding(int64_t idx,
}
}
}
}
std
::
vector
<
std
::
size_t
>
calc_dyn_auto_pad
(
std
::
vector
<
std
::
size_t
>
tensor
_lens
,
std
::
vector
<
std
::
size_t
>
calc_dyn_auto_pad
(
const
std
::
vector
<
std
::
size_t
>
&
input
_lens
,
std
::
vector
<
std
::
size_t
>
k
_lens
,
const
std
::
vector
<
std
::
size_t
>
&
wei
_lens
,
std
::
vector
<
std
::
size_t
>
strides
,
const
std
::
vector
<
std
::
size_t
>
&
strides
,
std
::
vector
<
std
::
size_t
>
dilations
,
const
std
::
vector
<
std
::
size_t
>
&
dilations
,
bool
use_upper
)
bool
use_upper
)
{
{
std
::
vector
<
std
::
size_t
>
padding
;
std
::
vector
<
std
::
size_t
>
padding
;
padding
.
resize
(
2
*
k_lens
.
size
());
assert
(
input_lens
.
size
()
>=
3
);
for
(
std
::
size_t
i
=
0
;
i
<
padding
.
size
()
/
2
;
i
++
)
std
::
size_t
num_spatial_dims
=
input_lens
.
size
()
-
2
;
padding
.
resize
(
2
*
num_spatial_dims
);
for
(
std
::
size_t
i
=
0
;
i
<
num_spatial_dims
;
i
++
)
{
{
std
::
ptrdiff_t
input_dim
=
tensor
_lens
[
i
];
std
::
ptrdiff_t
input_dim
=
input
_lens
[
i
+
2
];
std
::
ptrdiff_t
stride
=
strides
[
i
];
std
::
ptrdiff_t
stride
=
strides
[
i
];
std
::
ptrdiff_t
weight_dim
=
k
_lens
[
i
];
std
::
ptrdiff_t
weight_dim
=
wei
_lens
[
i
+
2
];
std
::
ptrdiff_t
dilation
=
dilations
[
i
];
std
::
ptrdiff_t
dilation
=
dilations
[
i
];
std
::
ptrdiff_t
output_dim
=
(
input_dim
+
stride
-
1
)
/
stride
;
// round up result
std
::
ptrdiff_t
output_dim
=
(
input_dim
+
stride
-
1
)
/
stride
;
// round up result
std
::
ptrdiff_t
new_weight_dim
=
weight_dim
+
(
weight_dim
-
1
)
*
(
dilation
-
1
);
std
::
ptrdiff_t
new_weight_dim
=
weight_dim
+
(
weight_dim
-
1
)
*
(
dilation
-
1
);
...
@@ -86,5 +88,28 @@ std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
...
@@ -86,5 +88,28 @@ std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
return
padding
;
return
padding
;
}
}
shape
compute_padded_shape
(
const
shape
&
input
,
const
shape
&
weights
,
const
std
::
vector
<
std
::
size_t
>&
padding
,
const
std
::
vector
<
std
::
size_t
>&
stride
,
const
std
::
vector
<
std
::
size_t
>&
dilation
)
{
const
size_t
num_spatial_dims
=
input
.
lens
().
size
()
-
2
;
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
]};
// calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
auto
padding_factor
=
padding
[
i
]
+
padding
[
i
+
num_spatial_dims
];
output_lens
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
weights
.
lens
()[
i
+
2
]
-
1
))
+
padding_factor
)
/
stride
[
i
]
+
1
)));
}
return
input
.
with_lens
(
output_lens
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/pass_manager.cpp
View file @
dae94657
...
@@ -94,11 +94,19 @@ struct module_pm : module_pass_manager
...
@@ -94,11 +94,19 @@ struct module_pm : module_pass_manager
virtual
void
run_pass
(
const
pass
&
p
)
override
virtual
void
run_pass
(
const
pass
&
p
)
override
{
{
assert
(
mod
);
assert
(
mod
);
timer
ts
{};
using
seconds
=
std
::
chrono
::
duration
<
double
>
;
trace
(
"Module: "
,
mod
->
name
(),
", Pass: "
,
p
.
name
());
trace
(
"Module: "
,
mod
->
name
(),
", Pass: "
,
p
.
name
());
const
double
t1
=
ts
.
record
<
seconds
>
();
assert
(
mod
->
validate
()
==
mod
->
end
());
assert
(
mod
->
validate
()
==
mod
->
end
());
p
.
apply
(
*
this
);
p
.
apply
(
*
this
);
trace
(
*
mod
);
trace
(
*
mod
);
validate_pass
(
*
mod
,
p
,
*
t
);
validate_pass
(
*
mod
,
p
,
*
t
);
const
double
t2
=
ts
.
record
<
seconds
>
();
trace
(
"Pass: "
,
p
.
name
(),
" completed in (s): "
,
(
t2
-
t1
));
}
}
};
};
...
...
src/program.cpp
View file @
dae94657
...
@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
...
@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
mm
->
print_graph
(
os
,
brief
);
mm
->
print_graph
(
os
,
brief
);
}
}
void
program
::
print_py
(
std
::
ostream
&
os
)
const
{
auto
vec_modules
=
this
->
get_modules
();
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
os
<<
"p = migraphx.program()
\n
"
;
for
(
auto
&
mod
:
vec_modules
)
{
std
::
string
var_name
=
"m"
+
mod
->
name
();
os
<<
var_name
<<
" = "
;
if
(
mod
->
name
()
==
"main"
)
os
<<
"p.get_main_module()"
;
else
os
<<
"p.create_module(
\"
"
<<
mod
->
name
()
<<
"
\"
);"
;
os
<<
std
::
endl
;
names
=
mod
->
print_py
(
os
,
var_name
,
names
);
os
<<
std
::
endl
;
}
}
void
program
::
print_cpp
(
std
::
ostream
&
os
)
const
void
program
::
print_cpp
(
std
::
ostream
&
os
)
const
{
{
auto
vec_modules
=
this
->
get_modules
();
auto
vec_modules
=
this
->
get_modules
();
...
...
src/rewrite_rnn.cpp
View file @
dae94657
...
@@ -46,9 +46,6 @@
...
@@ -46,9 +46,6 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -95,7 +92,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
...
@@ -95,7 +92,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process sequence length
// process sequence length
instruction_ref
seq_lens
=
m
.
end
();
instruction_ref
seq_lens
=
m
.
end
();
if
((
args
.
size
()
>=
5
)
&&
args
[
4
]
->
name
()
!=
"
undefined
"
)
if
((
args
.
size
()
>=
5
)
and
not
args
[
4
]
->
is_
undefined
()
)
{
{
seq_lens
=
args
[
4
];
seq_lens
=
args
[
4
];
}
}
...
@@ -120,7 +117,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
...
@@ -120,7 +117,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process bias
// process bias
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias_forward
=
m
.
insert_instruction
(
bias_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
...
@@ -132,7 +129,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
...
@@ -132,7 +129,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// or the 5th one (if the sequence len argument is ignored)
// or the 5th one (if the sequence len argument is ignored)
instruction_ref
ih_forward
{};
instruction_ref
ih_forward
{};
instruction_ref
ih_reverse
{};
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih_forward
=
m
.
insert_instruction
(
ih_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
...
@@ -198,14 +195,14 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
...
@@ -198,14 +195,14 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process bias and initial hidden state
// process bias and initial hidden state
instruction_ref
bias
=
m
.
end
();
instruction_ref
bias
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias
=
args
[
3
];
bias
=
args
[
3
];
}
}
// process intial hidden state
// process intial hidden state
instruction_ref
ih
;
instruction_ref
ih
;
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih
=
args
[
5
];
ih
=
args
[
5
];
}
}
...
@@ -401,7 +398,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
...
@@ -401,7 +398,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// process sequence length
// process sequence length
instruction_ref
seq_lens
=
m
.
end
();
instruction_ref
seq_lens
=
m
.
end
();
if
((
args
.
size
()
>=
5
)
&&
args
[
4
]
->
name
()
!=
"
undefined
"
)
if
((
args
.
size
()
>=
5
)
and
not
args
[
4
]
->
is_
undefined
()
)
{
{
seq_lens
=
args
[
4
];
seq_lens
=
args
[
4
];
}
}
...
@@ -426,7 +423,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
...
@@ -426,7 +423,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// bias
// bias
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias_forward
=
m
.
insert_instruction
(
bias_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
...
@@ -437,7 +434,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
...
@@ -437,7 +434,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// intial hidden state
// intial hidden state
instruction_ref
ih_forward
{};
instruction_ref
ih_forward
{};
instruction_ref
ih_reverse
{};
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih_forward
=
m
.
insert_instruction
(
ih_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
...
@@ -504,14 +501,14 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
...
@@ -504,14 +501,14 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// bias
// bias
instruction_ref
bias
=
m
.
end
();
instruction_ref
bias
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias
=
args
[
3
];
bias
=
args
[
3
];
}
}
// intial hidden state
// intial hidden state
instruction_ref
ih
{};
instruction_ref
ih
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih
=
args
[
5
];
ih
=
args
[
5
];
}
}
...
@@ -787,7 +784,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -787,7 +784,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process sequence length
// process sequence length
instruction_ref
seq_lens
=
m
.
end
();
instruction_ref
seq_lens
=
m
.
end
();
if
((
args
.
size
()
>=
5
)
&&
args
[
4
]
->
name
()
!=
"
undefined
"
)
if
((
args
.
size
()
>=
5
)
and
not
args
[
4
]
->
is_
undefined
()
)
{
{
seq_lens
=
args
[
4
];
seq_lens
=
args
[
4
];
}
}
...
@@ -816,7 +813,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -816,7 +813,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process bias
// process bias
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias_forward
=
m
.
insert_instruction
(
bias_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
...
@@ -827,7 +824,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -827,7 +824,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process intial hidden state, it is the 6th argument
// process intial hidden state, it is the 6th argument
instruction_ref
ih_forward
{};
instruction_ref
ih_forward
{};
instruction_ref
ih_reverse
{};
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
>=
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih_forward
=
m
.
insert_instruction
(
ih_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
...
@@ -843,7 +840,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -843,7 +840,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process initial cell value
// process initial cell value
instruction_ref
ic_forward
{};
instruction_ref
ic_forward
{};
instruction_ref
ic_reverse
{};
instruction_ref
ic_reverse
{};
if
(
args
.
size
()
>=
7
&&
args
[
6
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
7
and
not
args
[
6
]
->
is_
undefined
()
)
{
{
ic_forward
=
m
.
insert_instruction
(
ic_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
6
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
6
]);
...
@@ -859,7 +856,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -859,7 +856,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process weight of the peephole
// process weight of the peephole
instruction_ref
pph_forward
=
m
.
end
();
instruction_ref
pph_forward
=
m
.
end
();
instruction_ref
pph_reverse
=
m
.
end
();
instruction_ref
pph_reverse
=
m
.
end
();
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
8
and
not
args
[
7
]
->
is_
undefined
()
)
{
{
pph_forward
=
m
.
insert_instruction
(
pph_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
7
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
7
]);
...
@@ -943,14 +940,14 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -943,14 +940,14 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// bias
// bias
instruction_ref
bias
=
m
.
end
();
instruction_ref
bias
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias
=
args
[
3
];
bias
=
args
[
3
];
}
}
// initial hidden state
// initial hidden state
instruction_ref
ih
{};
instruction_ref
ih
{};
if
(
args
.
size
()
>=
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih
=
args
[
5
];
ih
=
args
[
5
];
}
}
...
@@ -961,7 +958,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -961,7 +958,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// initial cell value
// initial cell value
instruction_ref
ic
{};
instruction_ref
ic
{};
if
(
args
.
size
()
>=
7
&&
args
[
6
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
7
and
not
args
[
6
]
->
is_
undefined
()
)
{
{
ic
=
args
[
6
];
ic
=
args
[
6
];
}
}
...
@@ -972,7 +969,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -972,7 +969,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process weight of the peephole
// process weight of the peephole
instruction_ref
pph
=
m
.
end
();
instruction_ref
pph
=
m
.
end
();
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
8
and
not
args
[
7
]
->
is_
undefined
()
)
{
{
pph
=
args
[
7
];
pph
=
args
[
7
];
}
}
...
...
src/shape.cpp
View file @
dae94657
...
@@ -71,6 +71,19 @@ struct shape_impl
...
@@ -71,6 +71,19 @@ struct shape_impl
{
{
}
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
mins
,
std
::
vector
<
std
::
size_t
>
maxes
,
std
::
vector
<
std
::
size_t
>
opts
)
:
m_type
(
t
)
{
assert
(
mins
.
size
()
==
maxes
.
size
()
and
maxes
.
size
()
==
opts
.
size
());
for
(
size_t
i
=
0
;
i
<
mins
.
size
();
++
i
)
{
m_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
mins
[
i
],
maxes
[
i
],
opts
[
i
]});
}
}
shape_impl
(
const
std
::
vector
<
shape
>&
subs
)
:
m_type
(
shape
::
tuple_type
),
m_shapes
(
subs
)
{}
shape_impl
(
const
std
::
vector
<
shape
>&
subs
)
:
m_type
(
shape
::
tuple_type
),
m_shapes
(
subs
)
{}
shape
::
type_t
m_type
;
shape
::
type_t
m_type
;
...
@@ -224,6 +237,14 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
...
@@ -224,6 +237,14 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
{
{
}
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
mins
,
std
::
vector
<
std
::
size_t
>
maxes
,
std
::
vector
<
std
::
size_t
>
opts
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
mins
),
std
::
move
(
maxes
),
std
::
move
(
opts
)))
{
}
shape
::
shape
(
const
std
::
vector
<
shape
>&
subs
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
subs
))
{}
shape
::
shape
(
const
std
::
vector
<
shape
>&
subs
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
subs
))
{}
shape
::
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
)
:
impl
(
std
::
move
(
pimpl
))
{}
shape
::
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
)
:
impl
(
std
::
move
(
pimpl
))
{}
...
@@ -244,6 +265,15 @@ const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
...
@@ -244,6 +265,15 @@ const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
std
::
size_t
shape
::
ndim
()
const
{
if
(
this
->
dynamic
())
{
return
dyn_dims
().
size
();
}
return
lens
().
size
();
}
std
::
size_t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
std
::
size_t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
std
::
size_t
shape
::
bytes
()
const
std
::
size_t
shape
::
bytes
()
const
...
@@ -437,6 +467,16 @@ shape shape::with_type(type_t t) const
...
@@ -437,6 +467,16 @@ shape shape::with_type(type_t t) const
return
{
c
};
return
{
c
};
}
}
shape
shape
::
to_dynamic
()
const
{
if
(
this
->
dynamic
())
{
return
*
this
;
}
std
::
vector
<
std
::
size_t
>
zeroes
(
this
->
ndim
(),
0
);
return
{
type
(),
lens
(),
lens
(),
zeroes
};
}
std
::
size_t
shape
::
element_space
()
const
{
return
impl
->
element_space
();
}
std
::
size_t
shape
::
element_space
()
const
{
return
impl
->
element_space
();
}
std
::
string
shape
::
type_string
()
const
{
return
name
(
this
->
type
());
}
std
::
string
shape
::
type_string
()
const
{
return
name
(
this
->
type
());
}
...
@@ -464,15 +504,11 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
...
@@ -464,15 +504,11 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
template
<
class
Self
,
class
F
>
auto
shape
::
dynamic_dimension
::
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
min
,
"min"
),
f
(
self
.
max
,
"max"
),
f
(
self
.
opt
,
"opt"
));
}
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
{
return
(
x
.
min
==
y
.
min
and
x
.
max
==
y
.
max
and
x
.
opt
==
y
.
opt
);
// don't check opt if both are fixed
return
(
x
.
min
==
y
.
min
and
x
.
max
==
y
.
max
and
((
x
.
is_fixed
()
and
y
.
is_fixed
())
or
(
x
.
opt
==
y
.
opt
)));
}
}
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
...
@@ -485,6 +521,14 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
...
@@ -485,6 +521,14 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
return
os
;
return
os
;
}
}
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
)
{
return
x
.
min
==
y
and
x
.
max
==
y
;
}
bool
operator
==
(
const
std
::
size_t
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
y
==
x
;
}
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
)
{
return
not
(
x
==
y
);
}
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
not
(
x
==
y
);
}
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
)
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
)
{
{
if
(
x
.
dynamic
()
and
y
.
dynamic
())
if
(
x
.
dynamic
()
and
y
.
dynamic
())
...
...
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment