Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5ec8f913
Commit
5ec8f913
authored
Sep 13, 2022
by
Ted Themistokleous
Committed by
Ted Themistokleous
Sep 13, 2022
Browse files
Merge branch 'develop' into simplify_1_mul_div_ops
parents
32d69e8e
d78bcdfb
Changes
183
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
174 additions
and
92 deletions
+174
-92
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+1
-1
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+4
-4
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+1
-1
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+1
-1
src/include/migraphx/rewrite_gelu.hpp
src/include/migraphx/rewrite_gelu.hpp
+48
-0
src/include/migraphx/schedule_model.hpp
src/include/migraphx/schedule_model.hpp
+2
-2
src/include/migraphx/stream_model.hpp
src/include/migraphx/stream_model.hpp
+2
-2
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+1
-1
src/include/migraphx/supported_segments.hpp
src/include/migraphx/supported_segments.hpp
+12
-4
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+35
-32
src/include/migraphx/target_assignments.hpp
src/include/migraphx/target_assignments.hpp
+14
-3
src/include/migraphx/tensor_view.hpp
src/include/migraphx/tensor_view.hpp
+10
-10
src/include/migraphx/tune_axis.hpp
src/include/migraphx/tune_axis.hpp
+1
-1
src/instruction.cpp
src/instruction.cpp
+6
-6
src/make_op.cpp
src/make_op.cpp
+5
-0
src/module.cpp
src/module.cpp
+12
-9
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+6
-5
src/onnx/include/migraphx/onnx/onnx_parser.hpp
src/onnx/include/migraphx/onnx/onnx_parser.hpp
+1
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+7
-0
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+5
-10
No files found.
src/include/migraphx/program.hpp
View file @
5ec8f913
...
@@ -124,7 +124,7 @@ struct program
...
@@ -124,7 +124,7 @@ struct program
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
not
(
x
==
y
);
}
// module related api
// module related api
module
*
create_module
(
const
std
::
string
&
name
);
module
*
create_module
(
const
std
::
string
&
name
);
...
...
src/include/migraphx/raw_data.hpp
View file @
5ec8f913
...
@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
...
@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
template
<
class
T
>
template
<
class
T
>
bool
matches
()
const
bool
matches
()
const
{
{
return
is_data_ptr
<
T
>
{}
||
return
is_data_ptr
<
T
>
{}
or
self
->
get_shape
().
type
()
==
migraphx
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
self
->
get_shape
().
type
()
==
migraphx
::
shape
::
get_type
<
get_data_type
<
T
>>
{};
}
}
...
@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
...
@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
{
{
auto
&&
s
=
x
.
get_shape
();
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
...
vs
)
{
detail
::
visit_all_pack
(
s
,
vs
...)(
x
,
xs
...);
};
return
[
&
](
auto
...
vs
)
{
detail
::
visit_all_pack
(
s
,
vs
...)(
x
,
xs
...);
};
}
}
...
@@ -241,7 +241,7 @@ template <class T>
...
@@ -241,7 +241,7 @@ template <class T>
auto
visit_all
(
const
std
::
vector
<
T
>&
x
)
auto
visit_all
(
const
std
::
vector
<
T
>&
x
)
{
{
auto
&&
s
=
x
.
front
().
get_shape
();
auto
&&
s
=
x
.
front
().
get_shape
();
if
(
!
std
::
all_of
(
if
(
not
std
::
all_of
(
x
.
begin
(),
x
.
end
(),
[
&
](
const
T
&
y
)
{
return
y
.
get_shape
().
type
()
==
s
.
type
();
}))
x
.
begin
(),
x
.
end
(),
[
&
](
const
T
&
y
)
{
return
y
.
get_shape
().
type
()
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
return
[
&
](
auto
v
)
{
return
[
&
](
auto
v
)
{
...
@@ -281,7 +281,7 @@ template <class T,
...
@@ -281,7 +281,7 @@ template <class T,
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/reflect.hpp
View file @
5ec8f913
...
@@ -129,7 +129,7 @@ template <class T>
...
@@ -129,7 +129,7 @@ template <class T>
struct
reflect_equality
struct
reflect_equality
{
{
friend
bool
operator
==
(
const
T
&
x
,
const
T
&
y
)
{
return
reflect_tie
(
x
)
==
reflect_tie
(
y
);
}
friend
bool
operator
==
(
const
T
&
x
,
const
T
&
y
)
{
return
reflect_tie
(
x
)
==
reflect_tie
(
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
T
&
x
,
const
T
&
y
)
{
return
not
(
x
==
y
);
}
};
};
template
<
class
T
>
template
<
class
T
>
...
...
src/include/migraphx/requires.hpp
View file @
5ec8f913
...
@@ -31,7 +31,7 @@ namespace migraphx {
...
@@ -31,7 +31,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
bool
...
Bs
>
template
<
bool
...
Bs
>
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
||
true
)...
>>
// NOLINT
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
or
true
)...
>>
// NOLINT
{
{
};
};
...
...
src/include/migraphx/rewrite_gelu.hpp
0 → 100644
View file @
5ec8f913
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
/**
* Rewrite gelu standard formula as the sigmoid approximation formula
*/
struct
rewrite_gelu
{
std
::
string
name
()
const
{
return
"rewrite_gelu"
;
}
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/schedule_model.hpp
View file @
5ec8f913
...
@@ -208,7 +208,7 @@ struct schedule_model
...
@@ -208,7 +208,7 @@ struct schedule_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -274,7 +274,7 @@ struct schedule_model
...
@@ -274,7 +274,7 @@ struct schedule_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/stream_model.hpp
View file @
5ec8f913
...
@@ -216,7 +216,7 @@ struct stream_model
...
@@ -216,7 +216,7 @@ struct stream_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -288,7 +288,7 @@ struct stream_model
...
@@ -288,7 +288,7 @@ struct stream_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/streamutils.hpp
View file @
5ec8f913
...
@@ -41,7 +41,7 @@ struct stream_range_container
...
@@ -41,7 +41,7 @@ struct stream_range_container
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
stream_range_container
&
sr
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
stream_range_container
&
sr
)
{
{
assert
(
sr
.
r
!=
nullptr
);
assert
(
sr
.
r
!=
nullptr
);
if
(
!
sr
.
r
->
empty
())
if
(
not
sr
.
r
->
empty
())
{
{
os
<<
sr
.
r
->
front
();
os
<<
sr
.
r
->
front
();
std
::
for_each
(
std
::
for_each
(
...
...
src/
target_assign
ments.
c
pp
→
src/
include/migraphx/supported_seg
ments.
h
pp
View file @
5ec8f913
...
@@ -21,16 +21,24 @@
...
@@ -21,16 +21,24 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
#include <migraphx/target_assignments.hpp>
#include <unordered_set>
#include <migraphx/instruction_ref.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
target_assignments
::
add_assignment
(
instruction_ref
ins
,
const
std
::
string
&
target
)
struct
supported_segment
{
{
assignments
.
emplace
(
ins
,
target
);
std
::
unordered_set
<
instruction_ref
>
instructions
;
}
float
metric
;
};
using
supported_segments
=
std
::
vector
<
supported_segment
>
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
src/include/migraphx/target.hpp
View file @
5ec8f913
...
@@ -37,8 +37,10 @@
...
@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -64,12 +66,12 @@ struct target
...
@@ -64,12 +66,12 @@ struct target
*/
*/
context
get_context
()
const
;
context
get_context
()
const
;
/**
/**
* @brief
Check how well an
instruction
is
supported on a target
with the given metric
* @brief
Get the ranges of
instruction
s that are
supported on a target
* @param
ins Instruction
to check
if it's
supported
* @param
module Module
to check
for
supported
instructions
* @param metric Used to define how the
return value
should be
interpret
ed
* @param metric Used to define how the
quality of the support
should be
measur
ed
* @return
T
he
value based on the chosen metric. Negative numbers mean unsupported
* @return
t
he
supported segments of the graph
*/
*/
float
is_supported
(
T
&
,
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
;
supported_segments
target_
is_supported
(
T
&
,
co
nst
_module
_ref
mod
,
support_metric
m
etric
)
const
;
/**
/**
* @brief copy an argument to the current target.
* @brief copy an argument to the current target.
*
*
...
@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
...
@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
}
}
template
<
class
T
>
template
<
class
T
>
float
target_
is
_supported
(
T
&
,
i
nst
ruction
_ref
,
support_metric
)
supported_segments
target_
find
_supported
(
T
&
,
co
nst
_module
_ref
,
support_metric
)
{
{
return
0
;
return
{}
;
}
}
#ifdef TYPE_ERASED_DECLARATION
#ifdef TYPE_ERASED_DECLARATION
...
@@ -132,7 +134,7 @@ struct target
...
@@ -132,7 +134,7 @@ struct target
//
//
context
get_context
()
const
;
context
get_context
()
const
;
// (optional)
// (optional)
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
;
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
;
// (optional)
// (optional)
argument
copy_to
(
const
argument
&
input
)
const
;
argument
copy_to
(
const
argument
&
input
)
const
;
// (optional)
// (optional)
...
@@ -224,10 +226,10 @@ struct target
...
@@ -224,10 +226,10 @@ struct target
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
}
}
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
is
_supported
(
ins
,
m
);
return
(
*
this
).
private_detail_te_get_handle
().
find
_supported
(
mod
,
m
);
}
}
argument
copy_to
(
const
argument
&
input
)
const
argument
copy_to
(
const
argument
&
input
)
const
...
@@ -261,33 +263,33 @@ struct target
...
@@ -261,33 +263,33 @@ struct target
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
,
virtual
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
,
const
compile_options
&
options
)
const
=
0
;
const
compile_options
&
options
)
const
=
0
;
virtual
context
get_context
()
const
=
0
;
virtual
context
get_context
()
const
=
0
;
virtual
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
=
0
;
virtual
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
=
0
;
virtual
argument
copy_to
(
const
argument
&
input
)
const
=
0
;
virtual
argument
copy_to
(
const
argument
&
input
)
const
=
0
;
virtual
argument
copy_from
(
const
argument
&
input
)
const
=
0
;
virtual
argument
copy_from
(
const
argument
&
input
)
const
=
0
;
virtual
argument
allocate
(
const
shape
&
s
)
const
=
0
;
virtual
argument
allocate
(
const
shape
&
s
)
const
=
0
;
};
};
template
<
class
T
>
template
<
class
T
>
static
auto
private_detail_te_default_
is
_supported
(
char
,
static
auto
private_detail_te_default_
find
_supported
(
char
,
T
&&
private_detail_te_self
,
T
&&
private_detail_te_self
,
instruction
_ref
ins
,
const_module
_ref
mod
,
support_metric
m
)
support_metric
m
)
->
decltype
(
private_detail_te_self
.
is
_supported
(
ins
,
m
))
->
decltype
(
private_detail_te_self
.
find
_supported
(
mod
,
m
))
{
{
return
private_detail_te_self
.
is
_supported
(
ins
,
m
);
return
private_detail_te_self
.
find
_supported
(
mod
,
m
);
}
}
template
<
class
T
>
template
<
class
T
>
static
float
private_detail_te_default_
is
_supported
(
float
,
static
supported_segments
private_detail_te_default_
find
_supported
(
float
,
T
&&
private_detail_te_self
,
T
&&
private_detail_te_self
,
instruction
_ref
ins
,
const_module
_ref
mod
,
support_metric
m
)
support_metric
m
)
{
{
return
target_
is
_supported
(
private_detail_te_self
,
ins
,
m
);
return
target_
find
_supported
(
private_detail_te_self
,
mod
,
m
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -349,7 +351,7 @@ struct target
...
@@ -349,7 +351,7 @@ struct target
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -372,10 +374,11 @@ struct target
...
@@ -372,10 +374,11 @@ struct target
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
float
is
_supported
(
i
nst
ruction
_ref
ins
,
support_metric
m
)
const
override
supported_segments
find
_supported
(
co
nst
_module
_ref
mod
,
support_metric
m
)
const
override
{
{
return
private_detail_te_default_is_supported
(
char
(
0
),
private_detail_te_value
,
ins
,
m
);
return
private_detail_te_default_find_supported
(
char
(
0
),
private_detail_te_value
,
mod
,
m
);
}
}
argument
copy_to
(
const
argument
&
input
)
const
override
argument
copy_to
(
const
argument
&
input
)
const
override
...
@@ -423,7 +426,7 @@ struct target
...
@@ -423,7 +426,7 @@ struct target
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/target_assignments.hpp
View file @
5ec8f913
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <unordered_map>
#include <unordered_map>
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
...
@@ -33,10 +34,20 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -33,10 +34,20 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
target_assignments
struct
target_assignments
{
{
void
add_assignment
(
instruction_ref
ins
,
const
std
::
string
&
target
);
using
iterator
=
std
::
unordered_map
<
instruction_ref
,
std
::
string
>::
const_iterator
;
using
value_type
=
std
::
pair
<
instruction_ref
,
std
::
string
>
;
auto
begin
()
const
{
return
assignments
.
cbegin
();
}
auto
size
()
const
{
return
assignments
.
size
();
}
auto
end
()
const
{
return
assignments
.
cend
();
}
auto
&
at
(
instruction_ref
ins
)
const
{
return
assignments
.
at
(
ins
);
}
auto
insert
(
iterator
it
,
const
std
::
pair
<
instruction_ref
,
std
::
string
>&
assignment
)
{
return
assignments
.
insert
(
it
,
assignment
);
}
auto
find
(
instruction_ref
ins
)
const
{
return
assignments
.
find
(
ins
);
}
auto
begin
()
const
{
return
assignments
.
begin
();
}
auto
end
()
const
{
return
assignments
.
end
();
}
private:
private:
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
assignments
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
assignments
;
...
...
src/include/migraphx/tensor_view.hpp
View file @
5ec8f913
...
@@ -67,7 +67,7 @@ struct tensor_view
...
@@ -67,7 +67,7 @@ struct tensor_view
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
bool
empty
()
const
{
return
m_data
==
nullptr
||
m_shape
.
lens
().
empty
();
}
bool
empty
()
const
{
return
m_data
==
nullptr
or
m_shape
.
lens
().
empty
();
}
std
::
size_t
size
()
const
{
return
m_shape
.
elements
();
}
std
::
size_t
size
()
const
{
return
m_shape
.
elements
();
}
...
@@ -109,37 +109,37 @@ struct tensor_view
...
@@ -109,37 +109,37 @@ struct tensor_view
T
&
operator
[](
std
::
size_t
i
)
T
&
operator
[](
std
::
size_t
i
)
{
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
assert
(
not
this
->
empty
()
&&
i
<
this
->
size
());
return
m_data
[
m_shape
.
index
(
i
)];
return
m_data
[
m_shape
.
index
(
i
)];
}
}
const
T
&
operator
[](
std
::
size_t
i
)
const
const
T
&
operator
[](
std
::
size_t
i
)
const
{
{
assert
(
!
this
->
empty
()
&&
i
<
this
->
size
());
assert
(
not
this
->
empty
()
&&
i
<
this
->
size
());
return
m_data
[
m_shape
.
index
(
i
)];
return
m_data
[
m_shape
.
index
(
i
)];
}
}
T
&
front
()
T
&
front
()
{
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
0
];
return
m_data
[
0
];
}
}
const
T
&
front
()
const
const
T
&
front
()
const
{
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
0
];
return
m_data
[
0
];
}
}
T
&
back
()
T
&
back
()
{
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
}
}
const
T
&
back
()
const
const
T
&
back
()
const
{
{
assert
(
!
this
->
empty
());
assert
(
not
this
->
empty
());
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
}
}
...
@@ -159,7 +159,7 @@ struct tensor_view
...
@@ -159,7 +159,7 @@ struct tensor_view
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
{
if
(
!
x
.
empty
())
if
(
not
x
.
empty
())
{
{
os
<<
as_number
(
x
.
front
());
os
<<
as_number
(
x
.
front
());
for
(
std
::
size_t
i
=
1
;
i
<
x
.
m_shape
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
x
.
m_shape
.
elements
();
i
++
)
...
@@ -182,7 +182,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
...
@@ -182,7 +182,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
get_shape
().
elements
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
x
.
get_shape
().
elements
();
i
++
)
{
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
if
(
not
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
return
false
;
}
}
return
true
;
return
true
;
...
@@ -193,7 +193,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
...
@@ -193,7 +193,7 @@ bool operator==(const tensor_view<T>& x, const tensor_view<U>& y)
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
U
>&
y
)
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
U
>&
y
)
{
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
template
<
class
T
>
template
<
class
T
>
...
...
src/include/migraphx/tune_axis.hpp
View file @
5ec8f913
...
@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -34,7 +34,7 @@ inline namespace MIGRAPHX_INLINE_NS {
inline
int
tune_axis
(
const
int
n_dim
,
const
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
inline
int
tune_axis
(
const
int
n_dim
,
const
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
{
{
if
(
axis
>=
n_dim
||
std
::
abs
(
axis
)
>
n_dim
)
if
(
axis
>=
n_dim
or
std
::
abs
(
axis
)
>
n_dim
)
{
{
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
}
}
...
...
src/instruction.cpp
View file @
5ec8f913
...
@@ -176,13 +176,13 @@ bool operator==(const instruction& x, const instruction& y)
...
@@ -176,13 +176,13 @@ bool operator==(const instruction& x, const instruction& y)
return
true
;
return
true
;
}
}
bool
operator
!=
(
const
instruction
&
x
,
const
instruction
&
y
)
{
return
!
(
x
==
y
);
}
bool
operator
!=
(
const
instruction
&
x
,
const
instruction
&
y
)
{
return
not
(
x
==
y
);
}
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
!
(
i
==
ref
);
}
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
not
(
i
==
ref
);
}
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
not
(
i
==
ref
);
}
void
instruction
::
add_output
(
instruction_ref
ins
)
void
instruction
::
add_output
(
instruction_ref
ins
)
{
{
...
@@ -361,7 +361,7 @@ void instruction::print(std::ostream& os,
...
@@ -361,7 +361,7 @@ void instruction::print(std::ostream& os,
os
<<
"{"
<<
ins
->
get_literal
()
<<
"}"
;
os
<<
"{"
<<
ins
->
get_literal
()
<<
"}"
;
}
}
if
(
!
ins
->
inputs
().
empty
())
if
(
not
ins
->
inputs
().
empty
())
{
{
char
delim
=
'('
;
char
delim
=
'('
;
for
(
auto
&&
arg
:
ins
->
inputs
())
for
(
auto
&&
arg
:
ins
->
inputs
())
...
@@ -374,7 +374,7 @@ void instruction::print(std::ostream& os,
...
@@ -374,7 +374,7 @@ void instruction::print(std::ostream& os,
}
}
// print module inputs
// print module inputs
if
(
!
ins
->
module_inputs
().
empty
())
if
(
not
ins
->
module_inputs
().
empty
())
{
{
std
::
string
delim
=
", ["
;
std
::
string
delim
=
", ["
;
for
(
auto
&&
mod_arg
:
ins
->
module_inputs
())
for
(
auto
&&
mod_arg
:
ins
->
module_inputs
())
...
@@ -446,7 +446,7 @@ operation instruction::normalized_operator() const
...
@@ -446,7 +446,7 @@ operation instruction::normalized_operator() const
if
(
this
->
need_normalization
())
if
(
this
->
need_normalization
())
{
{
auto
s
=
this
->
inputs
().
front
()
->
get_shape
();
auto
s
=
this
->
inputs
().
front
()
->
get_shape
();
if
(
!
normalize_attributes
(
o
,
s
.
max_lens
()))
if
(
not
normalize_attributes
(
o
,
s
.
max_lens
()))
return
this
->
get_operator
();
return
this
->
get_operator
();
}
}
return
o
;
return
o
;
...
...
src/make_op.cpp
View file @
5ec8f913
...
@@ -64,5 +64,10 @@ operation make_op_from_value(const std::string& name, const value& v)
...
@@ -64,5 +64,10 @@ operation make_op_from_value(const std::string& name, const value& v)
});
});
}
}
operation
make_json_op
(
const
std
::
string
&
name
,
const
std
::
string
&
s
)
{
return
make_op
(
name
,
from_json_string
(
convert_to_json
(
s
)));
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/module.cpp
View file @
5ec8f913
...
@@ -141,12 +141,12 @@ void module::set_bypass(bool b) { impl->bypass = b; }
...
@@ -141,12 +141,12 @@ void module::set_bypass(bool b) { impl->bypass = b; }
void
module
::
assign
(
const
module
&
m
)
void
module
::
assign
(
const
module
&
m
)
{
{
// copy the impl
// copy the impl
if
(
!
impl
)
if
(
not
impl
)
impl
=
std
::
make_unique
<
module_impl
>
();
impl
=
std
::
make_unique
<
module_impl
>
();
*
impl
=
*
m
.
impl
;
*
impl
=
*
m
.
impl
;
// clear instructions
// clear instructions
if
(
!
impl
->
instructions
.
empty
())
if
(
not
impl
->
instructions
.
empty
())
{
{
impl
->
clear
();
impl
->
clear
();
}
}
...
@@ -346,7 +346,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
...
@@ -346,7 +346,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
assert
(
out
->
valid
(
begin
()));
assert
(
out
->
valid
(
begin
()));
}
}
// Replacement should not be dead code unless its the last instruction
// Replacement should not be dead code unless its the last instruction
assert
(
!
rep
->
outputs
().
empty
()
or
rep
==
std
::
prev
(
end
()));
assert
(
not
rep
->
outputs
().
empty
()
or
rep
==
std
::
prev
(
end
()));
// Output of the original instruction should only be the replacement or empty
// Output of the original instruction should only be the replacement or empty
assert
(
ins
->
outputs
().
empty
()
or
std
::
all_of
(
ins
->
outputs
().
begin
(),
assert
(
ins
->
outputs
().
empty
()
or
std
::
all_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
ins
->
outputs
().
end
(),
...
@@ -598,7 +598,7 @@ instruction_ref module::validate() const
...
@@ -598,7 +598,7 @@ instruction_ref module::validate() const
auto
inputs
=
i
.
inputs
();
auto
inputs
=
i
.
inputs
();
bool
check_order
=
std
::
all_of
(
bool
check_order
=
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
in
)
{
return
has_instruction
(
in
);
});
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
in
)
{
return
has_instruction
(
in
);
});
return
!
i
.
valid
(
impl
->
instructions
.
begin
(),
check_order
);
return
not
i
.
valid
(
impl
->
instructions
.
begin
(),
check_order
);
});
});
}
}
...
@@ -754,7 +754,7 @@ void module::print_graph(std::ostream& os, bool brief) const
...
@@ -754,7 +754,7 @@ void module::print_graph(std::ostream& os, bool brief) const
label
=
to_string
(
ins
->
get_operator
());
label
=
to_string
(
ins
->
get_operator
());
os
<<
"
\t
"
<<
enclose_name
(
ins_names
.
at
(
ins
))
<<
"[label="
<<
enclose_name
(
label
)
<<
"]"
;
os
<<
"
\t
"
<<
enclose_name
(
ins_names
.
at
(
ins
))
<<
"[label="
<<
enclose_name
(
label
)
<<
"]"
;
os
<<
";"
<<
std
::
endl
;
os
<<
";"
<<
std
::
endl
;
if
(
!
ins
->
inputs
().
empty
())
if
(
not
ins
->
inputs
().
empty
())
{
{
for
(
auto
&&
arg
:
ins
->
inputs
())
for
(
auto
&&
arg
:
ins
->
inputs
())
{
{
...
@@ -788,12 +788,15 @@ static std::string cpp_var_name(const std::string& name)
...
@@ -788,12 +788,15 @@ static std::string cpp_var_name(const std::string& name)
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
{
os
<<
"migraphx::make_op("
<<
enclose_name
(
op
.
name
());
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
if
(
not
v
.
empty
())
if
(
not
v
.
empty
())
{
{
os
<<
", "
os
<<
"migraphx::make_json_op("
<<
enclose_name
(
op
.
name
());
<<
"migraphx::from_json_string("
<<
enclose_name
(
to_json_string
(
v
))
<<
")"
;
os
<<
", "
<<
enclose_name
(
to_json_string
(
v
));
}
else
{
os
<<
"migraphx::make_op("
<<
enclose_name
(
op
.
name
());
}
}
os
<<
")"
;
os
<<
")"
;
}
}
...
@@ -905,7 +908,7 @@ module& module::sort()
...
@@ -905,7 +908,7 @@ module& module::sort()
this
->
move_instruction
(
ins
,
this
->
begin
());
this
->
move_instruction
(
ins
,
this
->
begin
());
for
(
auto
child
:
ins
->
inputs
())
for
(
auto
child
:
ins
->
inputs
())
{
{
if
(
!
contains
(
this
->
impl
->
instructions
,
child
))
if
(
not
contains
(
this
->
impl
->
instructions
,
child
))
{
{
continue
;
continue
;
}
}
...
...
src/normalize_attributes.cpp
View file @
5ec8f913
...
@@ -79,14 +79,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
...
@@ -79,14 +79,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
{
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
include_max
))
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
include_max
))
{
{
if
(
!
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less_equal
<>
{}))
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less_equal
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR: value out of range!"
);
MIGRAPHX_THROW
(
"TUNE_VECTOR: value out of range!"
);
}
}
}
}
else
else
{
{
if
(
!
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less
<>
{}))
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR: value out of range!"
);
MIGRAPHX_THROW
(
"TUNE_VECTOR: value out of range!"
);
}
}
...
@@ -118,14 +118,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
...
@@ -118,14 +118,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
{
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
include_min
))
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
include_min
))
{
{
if
(
!
std
::
equal
(
min_vals
.
begin
(),
min_vals
.
end
(),
result
.
begin
(),
std
::
less_equal
<>
{}))
if
(
not
std
::
equal
(
min_vals
.
begin
(),
min_vals
.
end
(),
result
.
begin
(),
std
::
less_equal
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR: attribute out of range!"
);
MIGRAPHX_THROW
(
"TUNE_VECTOR: attribute out of range!"
);
}
}
}
}
else
else
{
{
if
(
!
std
::
equal
(
result
.
begin
(),
result
.
end
(),
min_vals
.
begin
(),
std
::
less
<>
{}))
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
min_vals
.
begin
(),
std
::
less
<>
{}))
{
{
MIGRAPHX_THROW
(
"TUNE_VECTOR: attribute out of range!"
);
MIGRAPHX_THROW
(
"TUNE_VECTOR: attribute out of range!"
);
}
}
...
@@ -174,7 +175,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
...
@@ -174,7 +175,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
tuned
=
true
;
tuned
=
true
;
}
}
}
}
if
(
!
attrs
.
contains
(
"normalize_axes"
))
if
(
not
attrs
.
contains
(
"normalize_axes"
))
{
{
return
tuned
;
return
tuned
;
}
}
...
...
src/onnx/include/migraphx/onnx/onnx_parser.hpp
View file @
5ec8f913
...
@@ -97,6 +97,7 @@ struct onnx_parser
...
@@ -97,6 +97,7 @@ struct onnx_parser
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
,
0
};
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
,
0
};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
;
bool
use_dyn_output
=
false
;
bool
skip_unknown_operators
=
false
;
bool
skip_unknown_operators
=
false
;
int64_t
max_loop_iterations
=
10
;
int64_t
max_loop_iterations
=
10
;
int64_t
opset_version
=
13
;
int64_t
opset_version
=
13
;
...
...
src/onnx/onnx.cpp
View file @
5ec8f913
...
@@ -60,8 +60,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
...
@@ -60,8 +60,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
{
parser
.
default_dyn_dim_value
=
options
.
default_dyn_dim_value
;
parser
.
default_dyn_dim_value
=
options
.
default_dyn_dim_value
;
}
}
if
(
not
options
.
map_input_dims
.
empty
()
and
not
options
.
map_dyn_input_dims
.
empty
())
{
MIGRAPHX_THROW
(
"PARSE_ONNX_FROM: both map_input_dims and map_dyn_input_dims non-empty, only"
"one should be used"
);
}
parser
.
skip_unknown_operators
=
options
.
skip_unknown_operators
;
parser
.
skip_unknown_operators
=
options
.
skip_unknown_operators
;
parser
.
max_loop_iterations
=
options
.
max_loop_iterations
;
parser
.
max_loop_iterations
=
options
.
max_loop_iterations
;
parser
.
use_dyn_output
=
options
.
use_dyn_output
;
if
(
options
.
print_program_on_error
)
if
(
options
.
print_program_on_error
)
{
{
...
@@ -80,6 +86,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
...
@@ -80,6 +86,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
{
parser
.
parse_from
(
std
::
forward
<
Ts
>
(
xs
)...);
parser
.
parse_from
(
std
::
forward
<
Ts
>
(
xs
)...);
}
}
return
std
::
move
(
parser
.
prog
);
return
std
::
move
(
parser
.
prog
);
}
}
...
...
src/onnx/onnx_parser.cpp
View file @
5ec8f913
...
@@ -187,7 +187,7 @@ operation onnx_parser::load(const std::string& name, const node_info& info) cons
...
@@ -187,7 +187,7 @@ operation onnx_parser::load(const std::string& name, const node_info& info) cons
void
onnx_parser
::
parse_undefined
(
module
*
mod
,
const
std
::
string
&
name
)
void
onnx_parser
::
parse_undefined
(
module
*
mod
,
const
std
::
string
&
name
)
{
{
if
(
!
contains
(
instructions
,
name
))
if
(
not
contains
(
instructions
,
name
))
{
{
auto
ins
=
mod
->
add_instruction
(
make_op
(
"undefined"
));
auto
ins
=
mod
->
add_instruction
(
make_op
(
"undefined"
));
instructions
[
name
]
=
ins
;
instructions
[
name
]
=
ins
;
...
@@ -256,11 +256,6 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
...
@@ -256,11 +256,6 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
)
void
onnx_parser
::
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
)
{
{
if
(
not
map_input_dims
.
empty
()
and
not
map_dyn_input_dims
.
empty
())
{
MIGRAPHX_THROW
(
"PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only"
"one should be used"
);
}
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
mod_insts
;
for
(
auto
&&
f
:
graph
.
initializer
())
for
(
auto
&&
f
:
graph
.
initializer
())
{
{
...
@@ -272,7 +267,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
...
@@ -272,7 +267,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
{
const
std
::
string
&
name
=
input
.
name
();
const
std
::
string
&
name
=
input
.
name
();
// input not in initializer_data, so it is a real input
// input not in initializer_data, so it is a real input
if
(
!
contains
(
mod_insts
,
name
))
if
(
not
contains
(
mod_insts
,
name
))
{
{
// ONNX specification does not specify how to deal with the
// ONNX specification does not specify how to deal with the
// scenario that a nested subgraph contains a parameter with the
// scenario that a nested subgraph contains a parameter with the
...
@@ -359,7 +354,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
...
@@ -359,7 +354,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
all_output_names
.
begin
(),
all_output_names
.
begin
(),
all_output_names
.
end
(),
all_output_names
.
end
(),
std
::
back_inserter
(
prog_output_names
),
std
::
back_inserter
(
prog_output_names
),
[
&
](
const
auto
&
name
)
{
return
!
(
name
.
empty
()
or
instructions
.
count
(
name
)
==
0
);
});
[
&
](
const
auto
&
name
)
{
return
not
(
name
.
empty
()
or
instructions
.
count
(
name
)
==
0
);
});
std
::
vector
<
instruction_ref
>
output_ins
;
std
::
vector
<
instruction_ref
>
output_ins
;
std
::
transform
(
prog_output_names
.
begin
(),
std
::
transform
(
prog_output_names
.
begin
(),
...
@@ -449,7 +444,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
...
@@ -449,7 +444,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
{
{
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
if
(
!
input_dims
.
empty
())
if
(
not
input_dims
.
empty
())
{
{
return
{
shape_type
,
input_dims
};
return
{
shape_type
,
input_dims
};
}
}
...
@@ -516,7 +511,7 @@ shape::type_t get_type(int dtype)
...
@@ -516,7 +511,7 @@ shape::type_t get_type(int dtype)
bool
is_type_float
(
shape
::
type_t
dtype
)
bool
is_type_float
(
shape
::
type_t
dtype
)
{
{
bool
r
=
false
;
bool
r
=
false
;
if
(
dtype
==
shape
::
float_type
||
dtype
==
shape
::
double_type
||
dtype
==
shape
::
half_type
)
if
(
dtype
==
shape
::
float_type
or
dtype
==
shape
::
double_type
or
dtype
==
shape
::
half_type
)
{
{
r
=
true
;
r
=
true
;
}
}
...
...
Prev
1
2
3
4
5
6
7
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment