Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
087c205e
"git@developer.sourcefind.cn:xdb4_94051/vllm.git" did not exist on "7addca5935c83806429d7ec557999a505e6f6a35"
Commit
087c205e
authored
Mar 04, 2019
by
Paul
Browse files
Merge from develop
parents
a3a9e469
e15b8333
Changes
255
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
979 additions
and
83 deletions
+979
-83
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+6
-0
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+6
-1
src/include/migraphx/iterator_for.hpp
src/include/migraphx/iterator_for.hpp
+2
-2
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+2
-2
src/include/migraphx/make_shared_array.hpp
src/include/migraphx/make_shared_array.hpp
+1
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+0
-1
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+18
-0
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+184
-10
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+461
-42
src/include/migraphx/par_dfor.hpp
src/include/migraphx/par_dfor.hpp
+51
-0
src/include/migraphx/par_for.hpp
src/include/migraphx/par_for.hpp
+82
-0
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+8
-2
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+6
-0
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+65
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+32
-9
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+6
-0
src/include/migraphx/tensor_view.hpp
src/include/migraphx/tensor_view.hpp
+3
-1
src/include/migraphx/type_name.hpp
src/include/migraphx/type_name.hpp
+1
-1
src/instruction.cpp
src/instruction.cpp
+38
-9
src/onnx/mnist.cpp
src/onnx/mnist.cpp
+7
-2
No files found.
src/include/migraphx/functional.hpp
View file @
087c205e
...
@@ -94,6 +94,12 @@ constexpr void each_args(F)
...
@@ -94,6 +94,12 @@ constexpr void each_args(F)
{
{
}
}
template
<
class
F
,
class
T
>
auto
unpack
(
F
f
,
T
&
x
)
{
return
sequence_c
<
std
::
tuple_size
<
T
>
{}
>
([
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
x
)...);
});
}
/// Implements a fix-point combinator
/// Implements a fix-point combinator
template
<
class
R
,
class
F
>
template
<
class
R
,
class
F
>
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
...
...
src/include/migraphx/instruction.hpp
View file @
087c205e
...
@@ -14,6 +14,7 @@ namespace migraphx {
...
@@ -14,6 +14,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
);
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
);
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
);
struct
instruction
struct
instruction
{
{
...
@@ -71,7 +72,11 @@ struct instruction
...
@@ -71,7 +72,11 @@ struct instruction
static
void
static
void
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
static
instruction_ref
get_output_alias
(
instruction_ref
ins
);
argument
eval
()
const
;
void
finalize
(
context
&
ctx
);
static
instruction_ref
get_output_alias
(
instruction_ref
ins
,
bool
shallow
=
false
);
private:
private:
// internal
// internal
...
...
src/include/migraphx/iterator_for.hpp
View file @
087c205e
...
@@ -17,9 +17,9 @@ struct iterator_for_range
...
@@ -17,9 +17,9 @@ struct iterator_for_range
struct
iterator
struct
iterator
{
{
base_iterator
i
;
base_iterator
i
;
base_iterator
operator
*
()
{
return
i
;
}
base_iterator
operator
*
()
const
{
return
i
;
}
base_iterator
operator
++
()
{
return
++
i
;
}
base_iterator
operator
++
()
{
return
++
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
{
return
i
!=
rhs
.
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
const
{
return
i
!=
rhs
.
i
;
}
};
};
iterator
begin
()
iterator
begin
()
...
...
src/include/migraphx/literal.hpp
View file @
087c205e
...
@@ -22,8 +22,8 @@ struct literal : raw_data<literal>
...
@@ -22,8 +22,8 @@ struct literal : raw_data<literal>
{
{
literal
()
{}
literal
()
{}
template
<
class
U
,
class
T
=
deduce
<
U
>
>
template
<
class
U
,
class
T
=
deduce
<
U
>
,
shape
::
type_t
ShapeType
=
shape
::
get_type
<
T
>
{}
>
literal
(
U
x
)
:
buffer
(
make_shared_array
<
char
>
(
sizeof
(
T
))),
m_shape
(
s
hape
::
get_type
<
T
>
{}
)
literal
(
U
x
)
:
buffer
(
make_shared_array
<
char
>
(
sizeof
(
T
))),
m_shape
(
S
hape
Type
)
{
{
static_assert
(
std
::
is_trivially_copyable
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivially_copyable
<
T
>
{},
"Literals can only be trivial types"
);
*
(
reinterpret_cast
<
T
*>
(
buffer
.
get
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
buffer
.
get
()))
=
x
;
...
...
src/include/migraphx/make_shared_array.hpp
View file @
087c205e
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
template
<
typename
T
>
template
<
typename
T
>
std
::
shared_ptr
<
T
>
make_shared_array
(
size_t
size
)
std
::
shared_ptr
<
T
>
make_shared_array
(
size_t
size
)
{
{
return
std
::
shared_ptr
<
T
>
(
new
T
[
size
],
std
::
default_delete
<
T
[]
>
());
return
std
::
shared_ptr
<
T
>
(
new
T
[
size
],
std
::
default_delete
<
T
[]
>
());
// NOLINT
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/matcher.hpp
View file @
087c205e
...
@@ -214,7 +214,6 @@ void find_matches(program& p, Ms&&... ms)
...
@@ -214,7 +214,6 @@ void find_matches(program& p, Ms&&... ms)
bool
match
=
false
;
bool
match
=
false
;
each_args
(
each_args
(
[
&
](
auto
&&
m
)
{
[
&
](
auto
&&
m
)
{
// cppcheck-suppress knownConditionTrueFalse
if
(
match
)
if
(
match
)
return
;
return
;
auto
r
=
match_instruction
(
p
,
ins
,
m
.
matcher
());
auto
r
=
match_instruction
(
p
,
ins
,
m
.
matcher
());
...
...
src/include/migraphx/onnx.hpp
View file @
087c205e
...
@@ -7,6 +7,24 @@
...
@@ -7,6 +7,24 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
unknown
{
std
::
string
op
;
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
if
(
input
.
empty
())
return
{};
else
return
input
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
/// Create a program from an onnx file
/// Create a program from an onnx file
program
parse_onnx
(
const
std
::
string
&
name
);
program
parse_onnx
(
const
std
::
string
&
name
);
...
...
src/include/migraphx/operation.hpp
View file @
087c205e
...
@@ -7,17 +7,17 @@
...
@@ -7,17 +7,17 @@
#include <memory>
#include <memory>
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
context
;
#ifdef DOXYGEN
#ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All
/// The operation interface represents an action an instruction will perform. All
...
@@ -26,6 +26,8 @@ struct operation
...
@@ -26,6 +26,8 @@ struct operation
{
{
/// A unique name identifying the operation
/// A unique name identifying the operation
std
::
string
name
()
const
;
std
::
string
name
()
const
;
/// An optional method that can be used to finalize the operator before running
void
finalize
(
context
&
ctx
);
/// This is used to compute the resulting shape from an operation. If an
/// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an
/// operation cannot be run with input shapes, then it should throw an
/// exception.
/// exception.
...
@@ -53,6 +55,11 @@ struct operation
...
@@ -53,6 +55,11 @@ struct operation
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
};
};
/// Returns true if operation does not require a context to run compute
bool
is_context_free
(
const
operation
&
x
);
/// Returns true if the operation has a finalize method
bool
has_finalize
(
const
operation
&
x
);
#else
#else
namespace
operation_stream
{
namespace
operation_stream
{
...
@@ -89,7 +96,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
...
@@ -89,7 +96,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
}
// namespace operation_equal
}
// namespace operation_equal
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
T
&
x
,
context
&
ctx
,
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
...
@@ -99,6 +106,14 @@ auto compute_op(rank<1>,
...
@@ -99,6 +106,14 @@ auto compute_op(rank<1>,
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
{
...
@@ -110,7 +125,53 @@ template <class T>
...
@@ -110,7 +125,53 @@ template <class T>
argument
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
{
return
compute_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
return
compute_op
(
rank
<
2
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
std
::
declval
<
context
&>
()),
output_shape
,
input
))
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable without a context: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
2
>
{},
x
,
output_shape
,
input
);
}
template
<
class
T
>
auto
is_context_free_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
),
std
::
true_type
{});
template
<
class
T
>
auto
is_context_free_op
(
rank
<
0
>
,
const
T
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
->
std
::
false_type
;
template
<
class
T
>
auto
is_context_free_op
(
const
T
&
x
)
->
decltype
(
is_context_free_op
(
rank
<
1
>
{},
x
,
std
::
declval
<
const
shape
&>
(),
std
::
declval
<
std
::
vector
<
argument
>>
()))
{
return
{};
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -132,15 +193,57 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
...
@@ -132,15 +193,57 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return
output_alias_op
(
rank
<
1
>
{},
x
,
shapes
);
return
output_alias_op
(
rank
<
1
>
{},
x
,
shapes
);
}
}
template
<
class
T
>
auto
finalize_op
(
rank
<
1
>
,
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
->
decltype
(
x
.
finalize
(
auto_any_cast
(
ctx
),
output_shape
,
input
),
void
())
{
x
.
finalize
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
template
<
class
T
>
void
finalize_op
(
rank
<
0
>
,
T
&
,
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
{
}
template
<
class
T
>
void
finalize_op
(
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
{
finalize_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
has_finalize_op
(
rank
<
1
>
,
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
->
decltype
(
x
.
finalize
(
auto_any_cast
(
ctx
),
output_shape
,
input
),
std
::
true_type
{});
template
<
class
T
>
auto
has_finalize_op
(
rank
<
0
>
,
T
&
,
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
->
std
::
false_type
;
template
<
class
T
>
auto
has_finalize_op
(
const
T
&
)
->
decltype
(
has_finalize_op
(
rank
<
1
>
{},
std
::
declval
<
T
&>
(),
std
::
declval
<
context
&>
(),
std
::
declval
<
const
shape
&>
(),
std
::
declval
<
std
::
vector
<
shape
>>
()))
{
return
{};
}
/*
/*
* Type-erased interface for:
* Type-erased interface for:
*
*
* struct operation
* struct operation
* {
* {
* std::string name() const;
* std::string name() const;
* bool is_context_free() const;
* bool has_finalize() const;
* int output_alias(const std::vector<shape>& input) const;
* int output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
* };
...
@@ -210,12 +313,30 @@ struct operation
...
@@ -210,12 +313,30 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
bool
is_context_free
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
is_context_free
();
}
bool
has_finalize
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
has_finalize
();
}
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
output_alias
(
input
);
return
(
*
this
).
private_detail_te_get_handle
().
output_alias
(
input
);
}
}
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
finalize
(
ctx
,
output
,
input
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
@@ -228,6 +349,12 @@ struct operation
...
@@ -228,6 +349,12 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
ctx
,
output
,
input
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
ctx
,
output
,
input
);
}
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
output
,
input
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
{
assert
(
op
.
private_detail_te_handle_mem_var
);
assert
(
op
.
private_detail_te_handle_mem_var
);
...
@@ -240,6 +367,12 @@ struct operation
...
@@ -240,6 +367,12 @@ struct operation
return
x
.
private_detail_te_get_handle
().
operator
==
(
y
);
return
x
.
private_detail_te_get_handle
().
operator
==
(
y
);
}
}
friend
bool
is_shared
(
const
operation
&
private_detail_x
,
const
operation
&
private_detail_y
)
{
return
private_detail_x
.
private_detail_te_handle_mem_var
==
private_detail_y
.
private_detail_te_handle_mem_var
;
}
private:
private:
struct
private_detail_te_handle_base_type
struct
private_detail_te_handle_base_type
{
{
...
@@ -247,13 +380,18 @@ struct operation
...
@@ -247,13 +380,18 @@ struct operation
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
bool
is_context_free
()
const
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
bool
has_finalize
()
const
=
0
;
virtual
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
argument
virtual
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -286,12 +424,26 @@ struct operation
...
@@ -286,12 +424,26 @@ struct operation
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
bool
is_context_free
()
const
override
{
return
is_context_free_op
(
private_detail_te_value
);
}
bool
has_finalize
()
const
override
{
return
has_finalize_op
(
private_detail_te_value
);
}
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
{
{
return
output_alias_op
(
private_detail_te_value
,
input
);
return
output_alias_op
(
private_detail_te_value
,
input
);
}
}
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
override
{
finalize_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
override
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
override
{
{
...
@@ -306,6 +458,12 @@ struct operation
...
@@ -306,6 +458,12 @@ struct operation
return
compute_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
return
compute_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
}
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
override
{
return
compute_op
(
private_detail_te_value
,
output
,
input
);
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
{
using
migraphx
::
operation_stream
::
operator
<<
;
using
migraphx
::
operation_stream
::
operator
<<
;
...
@@ -385,6 +543,22 @@ inline const ValueType& any_cast(const operation& x)
...
@@ -385,6 +543,22 @@ inline const ValueType& any_cast(const operation& x)
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
is_context_free
(
const
operation
&
op
)
{
return
op
.
is_context_free
();
}
template
<
class
T
>
bool
is_context_free
(
const
T
&
x
)
{
return
is_context_free_op
(
x
);
}
inline
bool
has_finalize
(
const
operation
&
op
)
{
return
op
.
has_finalize
();
}
template
<
class
T
>
bool
has_finalize
(
const
T
&
x
)
{
return
has_finalize_op
(
x
);
}
#endif
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/operators.hpp
View file @
087c205e
This diff is collapsed.
Click to expand it.
src/include/migraphx/par_dfor.hpp
0 → 100644
View file @
087c205e
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
#include <array>
#include <numeric>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
...
Ts
>
auto
par_dfor
(
Ts
...
xs
)
{
return
[
=
](
auto
f
)
{
using
array_type
=
std
::
array
<
std
::
size_t
,
sizeof
...(
Ts
)
>
;
array_type
lens
=
{{
static_cast
<
std
::
size_t
>
(
xs
)...}};
auto
n
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
{});
const
std
::
size_t
min_grain
=
8
;
if
(
n
>
2
*
min_grain
)
{
array_type
strides
;
strides
.
fill
(
1
);
std
::
partial_sum
(
lens
.
rbegin
(),
lens
.
rend
()
-
1
,
strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
size
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
par_for
(
size
,
min_grain
,
[
&
](
std
::
size_t
i
)
{
array_type
indices
;
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
lens
.
begin
(),
indices
.
begin
(),
[
&
](
size_t
stride
,
size_t
len
)
{
return
(
i
/
stride
)
%
len
;
});
migraphx
::
unpack
(
f
,
indices
);
});
}
else
{
dfor
(
xs
...)(
f
);
}
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/par_for.hpp
0 → 100644
View file @
087c205e
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
joinable_thread
:
std
::
thread
{
template
<
class
...
Xs
>
joinable_thread
(
Xs
&&
...
xs
)
:
std
::
thread
(
std
::
forward
<
Xs
>
(
xs
)...)
// NOLINT
{
}
joinable_thread
&
operator
=
(
joinable_thread
&&
other
)
=
default
;
joinable_thread
(
joinable_thread
&&
other
)
=
default
;
~
joinable_thread
()
{
if
(
this
->
joinable
())
this
->
join
();
}
};
template
<
class
F
>
void
par_for_impl
(
std
::
size_t
n
,
std
::
size_t
threadsize
,
F
f
)
{
if
(
threadsize
<=
1
)
{
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
f
(
i
);
}
else
{
std
::
vector
<
joinable_thread
>
threads
(
threadsize
);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std
::
size_t
grainsize
=
std
::
ceil
(
static_cast
<
double
>
(
n
)
/
threads
.
size
());
std
::
size_t
work
=
0
;
std
::
generate
(
threads
.
begin
(),
threads
.
end
(),
[
=
,
&
work
]
{
auto
result
=
joinable_thread
([
=
]
{
std
::
size_t
start
=
work
;
std
::
size_t
last
=
std
::
min
(
n
,
work
+
grainsize
);
for
(
std
::
size_t
i
=
start
;
i
<
last
;
i
++
)
{
f
(
i
);
}
});
work
+=
grainsize
;
return
result
;
});
assert
(
work
>=
n
);
}
}
template
<
class
F
>
void
par_for
(
std
::
size_t
n
,
std
::
size_t
min_grain
,
F
f
)
{
const
auto
threadsize
=
std
::
min
<
std
::
size_t
>
(
std
::
thread
::
hardware_concurrency
(),
n
/
min_grain
);
par_for_impl
(
n
,
threadsize
,
f
);
}
template
<
class
F
>
void
par_for
(
std
::
size_t
n
,
F
f
)
{
const
int
min_grain
=
8
;
par_for
(
n
,
min_grain
,
f
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/pass.hpp
View file @
087c205e
...
@@ -105,7 +105,13 @@ struct pass
...
@@ -105,7 +105,13 @@ struct pass
void
apply
(
program
&
p
)
const
void
apply
(
program
&
p
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
}
friend
bool
is_shared
(
const
pass
&
private_detail_x
,
const
pass
&
private_detail_y
)
{
return
private_detail_x
.
private_detail_te_handle_mem_var
==
private_detail_y
.
private_detail_te_handle_mem_var
;
}
}
private:
private:
...
@@ -149,7 +155,7 @@ struct pass
...
@@ -149,7 +155,7 @@ struct pass
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
void
apply
(
program
&
p
)
const
override
{
return
private_detail_te_value
.
apply
(
p
);
}
void
apply
(
program
&
p
)
const
override
{
private_detail_te_value
.
apply
(
p
);
}
PrivateDetailTypeErasedT
private_detail_te_value
;
PrivateDetailTypeErasedT
private_detail_te_value
;
};
};
...
...
src/include/migraphx/program.hpp
View file @
087c205e
...
@@ -91,16 +91,22 @@ struct program
...
@@ -91,16 +91,22 @@ struct program
shape
get_shape
()
const
;
shape
get_shape
()
const
;
context
&
get_context
()
const
;
instruction_ref
validate
()
const
;
instruction_ref
validate
()
const
;
void
compile
(
const
target
&
t
,
tracer
trace
=
tracer
{});
void
compile
(
const
target
&
t
,
tracer
trace
=
tracer
{});
void
finalize
();
void
perf_report
(
std
::
ostream
&
os
,
std
::
size_t
n
,
parameter_map
params
)
const
;
void
perf_report
(
std
::
ostream
&
os
,
std
::
size_t
n
,
parameter_map
params
)
const
;
void
debug_print
()
const
;
void
debug_print
()
const
;
void
debug_print
(
instruction_ref
ins
)
const
;
void
debug_print
(
instruction_ref
ins
)
const
;
void
debug_print
(
const
std
::
vector
<
instruction_ref
>&
inss
)
const
;
void
debug_print
(
const
std
::
vector
<
instruction_ref
>&
inss
)
const
;
void
dry_run
(
parameter_map
params
)
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
...
...
src/include/migraphx/rewrite_rnn.hpp
0 → 100644
View file @
087c205e
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
program
;
/**
* Rewrite rnn to gemm and add.
*/
struct
rewrite_rnn
{
std
::
string
name
()
const
{
return
"rewrite_rnn"
;
}
void
apply
(
program
&
prog
)
const
;
private:
// for vanilla rnn operators
void
apply_vanilla_rnn
(
program
&
prog
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
vanilla_rnn_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
w
,
instruction_ref
r
,
instruction_ref
bias
,
instruction_ref
ih
,
operation
&
actv_func
)
const
;
std
::
vector
<
operation
>
vanilla_rnn_actv_funcs
(
instruction_ref
ins
)
const
;
// for gru operators
void
apply_gru
(
program
&
prog
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
gru_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
)
const
;
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
// for lstm operators
void
apply_lstm
(
program
&
prog
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
,
const
operation
&
actv_func3
)
const
;
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/shape.hpp
View file @
087c205e
...
@@ -35,22 +35,22 @@ struct shape
...
@@ -35,22 +35,22 @@ struct shape
m(uint64_type, uint64_t)
m(uint64_type, uint64_t)
// clang-format on
// clang-format on
#define MIGRAPHX_SHAPE_ENUM_TYPES(x, t) x,
#define MIGRAPHX_SHAPE_
GENERATE_
ENUM_TYPES(x, t) x,
enum
type_t
enum
type_t
{
{
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_ENUM_TYPES
)
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_
GENERATE_
ENUM_TYPES
)
};
};
#undef MIGRAPHX_SHAPE_ENUM_TYPES
#undef MIGRAPHX_SHAPE_
GENERATE_
ENUM_TYPES
template
<
class
T
,
class
=
void
>
template
<
class
T
,
class
=
void
>
struct
get_type
;
struct
get_type
;
#define MIGRAPHX_SHAPE_GET_TYPE(x, t)
\
#define MIGRAPHX_SHAPE_
GENERATE_
GET_TYPE(x, t) \
template <class T> \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \
{ \
};
};
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GET_TYPE
)
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_
GENERATE_
GET_TYPE
)
#undef MIGRAPHX_SHAPE_GET_TYPE
#undef MIGRAPHX_SHAPE_
GENERATE_
GET_TYPE
template
<
class
T
>
template
<
class
T
>
struct
get_type
<
const
T
>
:
get_type
<
T
>
struct
get_type
<
const
T
>
:
get_type
<
T
>
...
@@ -62,6 +62,19 @@ struct shape
...
@@ -62,6 +62,19 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
{
}
template
<
class
Range1
,
class
Range2
>
shape
(
type_t
t
,
const
Range1
&
l
,
const
Range2
&
s
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()),
std
::
vector
<
std
::
size_t
>
(
s
.
begin
(),
s
.
end
()))
{
}
type_t
type
()
const
;
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
...
@@ -141,6 +154,8 @@ struct shape
...
@@ -141,6 +154,8 @@ struct shape
{
{
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
}
}
type_t
type_enum
()
const
{
return
get_type
<
T
>
{};
}
};
};
template
<
class
Visitor
>
template
<
class
Visitor
>
...
@@ -148,14 +163,22 @@ struct shape
...
@@ -148,14 +163,22 @@ struct shape
{
{
switch
(
this
->
type
())
switch
(
this
->
type
())
{
{
#define MIGRAPHX_SHAPE_VISITOR_CASE(x, t) \
#define MIGRAPHX_SHAPE_
GENERATE_
VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
case x: v(as<t>()); return;
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_VISITOR_CASE
)
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_
GENERATE_
VISITOR_CASE
)
#undef MIGRAPHX_SHAPE_VISITOR_CASE
#undef MIGRAPHX_SHAPE_
GENERATE_
VISITOR_CASE
}
}
MIGRAPHX_THROW
(
"Unknown type"
);
MIGRAPHX_THROW
(
"Unknown type"
);
}
}
template
<
class
Visitor
>
static
void
visit_types
(
Visitor
v
)
{
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) v(as<t>());
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
}
private:
private:
std
::
shared_ptr
<
const
shape_impl
>
impl
;
std
::
shared_ptr
<
const
shape_impl
>
impl
;
...
...
src/include/migraphx/target.hpp
View file @
087c205e
...
@@ -127,6 +127,12 @@ struct target
...
@@ -127,6 +127,12 @@ struct target
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
}
}
friend
bool
is_shared
(
const
target
&
private_detail_x
,
const
target
&
private_detail_y
)
{
return
private_detail_x
.
private_detail_te_handle_mem_var
==
private_detail_y
.
private_detail_te_handle_mem_var
;
}
private:
private:
struct
private_detail_te_handle_base_type
struct
private_detail_te_handle_base_type
{
{
...
...
src/include/migraphx/tensor_view.hpp
View file @
087c205e
...
@@ -124,6 +124,8 @@ struct tensor_view
...
@@ -124,6 +124,8 @@ struct tensor_view
return
m_data
+
this
->
size
();
return
m_data
+
this
->
size
();
}
}
std
::
vector
<
T
>
to_vector
()
const
{
return
std
::
vector
<
T
>
(
this
->
begin
(),
this
->
end
());
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
{
if
(
!
x
.
empty
())
if
(
!
x
.
empty
())
...
@@ -164,7 +166,7 @@ bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
...
@@ -164,7 +166,7 @@ bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
}
}
template
<
class
T
>
template
<
class
T
>
tensor_view
<
T
>
make_view
(
shape
s
,
T
*
data
)
tensor_view
<
T
>
make_view
(
const
shape
&
s
,
T
*
data
)
{
{
return
{
s
,
data
};
return
{
s
,
data
};
}
}
...
...
src/include/migraphx/type_name.hpp
View file @
087c205e
...
@@ -18,7 +18,7 @@ const std::string& get_type_name()
...
@@ -18,7 +18,7 @@ const std::string& get_type_name()
name
=
typeid
(
PrivateMigraphTypeNameProbe
).
name
();
name
=
typeid
(
PrivateMigraphTypeNameProbe
).
name
();
name
=
name
.
substr
(
7
);
name
=
name
.
substr
(
7
);
#else
#else
const
char
parameter_name
[]
=
"PrivateMigraphTypeNameProbe ="
;
const
char
parameter_name
[]
=
"PrivateMigraphTypeNameProbe ="
;
// NOLINT
name
=
__PRETTY_FUNCTION__
;
name
=
__PRETTY_FUNCTION__
;
...
...
src/instruction.cpp
View file @
087c205e
...
@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
...
@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
bool
operator
==
(
const
instruction
&
x
,
const
instruction
&
y
)
bool
operator
==
(
const
instruction
&
x
,
const
instruction
&
y
)
{
{
if
(
not
(
x
.
result
==
y
.
result
and
x
.
op
==
y
.
op
and
x
.
arguments
==
y
.
arguments
))
if
(
std
::
tie
(
x
.
result
,
x
.
op
,
x
.
arguments
)
!=
std
::
tie
(
y
.
result
,
y
.
op
,
y
.
arguments
))
return
false
;
return
false
;
if
(
x
.
name
()
==
"@literal"
)
if
(
x
.
name
()
==
"@literal"
)
return
x
.
lit
==
y
.
lit
;
return
x
.
lit
==
y
.
lit
;
...
@@ -162,25 +162,54 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
...
@@ -162,25 +162,54 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old
->
remove_output
(
*
this
);
old
->
remove_output
(
*
this
);
}
}
std
::
vector
<
shape
>
compute_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
)
argument
instruction
::
eval
()
const
{
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
if
(
op
.
name
()
==
"@literal"
)
std
::
transform
(
{
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
instruction_ref
i
)
{
return
i
->
get_shape
();
});
return
this
->
get_literal
().
get_argument
();
return
shapes
;
}
if
(
is_context_free
(
op
))
{
std
::
vector
<
argument
>
args
;
for
(
auto
&&
arg
:
this
->
inputs
())
{
argument
a
=
arg
->
eval
();
if
(
a
.
empty
())
return
{};
args
.
push_back
(
a
);
}
return
op
.
compute
(
result
,
args
);
}
return
{};
}
}
instruction_ref
instruction
::
get_output_alias
(
instruction_ref
ins
)
void
instruction
::
finalize
(
context
&
ctx
)
{
{
auto
i
=
ins
->
get_operator
().
output_alias
(
compute_shapes
(
ins
->
inputs
()));
if
(
has_finalize
(
this
->
op
))
this
->
op
.
finalize
(
ctx
,
this
->
get_shape
(),
to_shapes
(
this
->
inputs
()));
}
instruction_ref
instruction
::
get_output_alias
(
instruction_ref
ins
,
bool
shallow
)
{
auto
i
=
ins
->
get_operator
().
output_alias
(
to_shapes
(
ins
->
inputs
()));
if
(
i
<
0
)
if
(
i
<
0
)
return
ins
;
return
ins
;
if
(
shallow
)
return
ins
->
inputs
().
at
(
i
);
return
get_output_alias
(
ins
->
inputs
().
at
(
i
));
return
get_output_alias
(
ins
->
inputs
().
at
(
i
));
}
}
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
instruction_ref
i
)
{
return
i
->
get_shape
();
});
return
shapes
;
}
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
)
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
)
{
{
return
op
.
compute_shape
(
compute
_shapes
(
args
));
return
op
.
compute_shape
(
to
_shapes
(
args
));
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/onnx/mnist.cpp
View file @
087c205e
...
@@ -14,7 +14,10 @@
...
@@ -14,7 +14,10 @@
auto
reverse_int
(
unsigned
int
i
)
auto
reverse_int
(
unsigned
int
i
)
{
{
unsigned
char
c1
,
c2
,
c3
,
c4
;
unsigned
char
c1
;
unsigned
char
c2
;
unsigned
char
c3
;
unsigned
char
c4
;
c1
=
i
&
255u
;
c1
=
i
&
255u
;
c2
=
(
i
>>
8u
)
&
255u
;
c2
=
(
i
>>
8u
)
&
255u
;
c3
=
(
i
>>
16u
)
&
255u
;
c3
=
(
i
>>
16u
)
&
255u
;
...
@@ -32,7 +35,9 @@ read_mnist_images(const std::string& full_path, int& number_of_images, int& imag
...
@@ -32,7 +35,9 @@ read_mnist_images(const std::string& full_path, int& number_of_images, int& imag
if
(
file
.
is_open
())
if
(
file
.
is_open
())
{
{
int
magic_number
=
0
,
n_rows
=
0
,
n_cols
=
0
;
int
magic_number
=
0
;
int
n_rows
=
0
;
int
n_cols
=
0
;
file
.
read
(
reinterpret_cast
<
char
*>
(
&
magic_number
),
sizeof
(
magic_number
));
file
.
read
(
reinterpret_cast
<
char
*>
(
&
magic_number
),
sizeof
(
magic_number
));
magic_number
=
reverse_int
(
magic_number
);
magic_number
=
reverse_int
(
magic_number
);
...
...
Prev
1
2
3
4
5
6
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment