gaoqiong / MIGraphX, commit 087c205e

Commit 087c205e, authored Mar 04, 2019 by Paul

Merge from develop

Parents: a3a9e469, e15b8333

Changes: 255. Showing 20 changed files with 979 additions and 83 deletions (+979 -83) on this page.
src/include/migraphx/functional.hpp         +6    -0
src/include/migraphx/instruction.hpp        +6    -1
src/include/migraphx/iterator_for.hpp       +2    -2
src/include/migraphx/literal.hpp            +2    -2
src/include/migraphx/make_shared_array.hpp  +1    -1
src/include/migraphx/matcher.hpp            +0    -1
src/include/migraphx/onnx.hpp               +18   -0
src/include/migraphx/operation.hpp          +184  -10
src/include/migraphx/operators.hpp          +461  -42
src/include/migraphx/par_dfor.hpp           +51   -0
src/include/migraphx/par_for.hpp            +82   -0
src/include/migraphx/pass.hpp               +8    -2
src/include/migraphx/program.hpp            +6    -0
src/include/migraphx/rewrite_rnn.hpp        +65   -0
src/include/migraphx/shape.hpp              +32   -9
src/include/migraphx/target.hpp             +6    -0
src/include/migraphx/tensor_view.hpp        +3    -1
src/include/migraphx/type_name.hpp          +1    -1
src/instruction.cpp                         +38   -9
src/onnx/mnist.cpp                          +7    -2
src/include/migraphx/functional.hpp

@@ -94,6 +94,12 @@ constexpr void each_args(F)
{
}

+template <class F, class T>
+auto unpack(F f, T& x)
+{
+    return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); });
+}
+
/// Implements a fix-point combinator
template <class R, class F>
detail::fix_f<R, F> fix(F f)
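The new unpack helper expands a tuple-like object into individual arguments for a callable; par_dfor below uses it to turn a std::array of indices into separate loop indices. A minimal usage sketch, assuming the MIGraphX headers are on the include path (the tuple and the printing lambda are illustrative only):

#include <iostream>
#include <tuple>
#include <migraphx/functional.hpp>

int main()
{
    std::tuple<int, double, const char*> t{1, 2.5, "three"};
    // unpack forwards each element of t as a separate argument to the callable.
    migraphx::unpack([](auto... xs) {
        int unused[] = {0, ((std::cout << xs << ' '), 0)...};
        (void)unused;
    }, t);
    std::cout << '\n';
}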
src/include/migraphx/instruction.hpp

@@ -14,6 +14,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
+std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);

struct instruction
{

@@ -71,7 +72,11 @@ struct instruction
    static void replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);

-   static instruction_ref get_output_alias(instruction_ref ins);
+   argument eval() const;
+
+   void finalize(context& ctx);
+
+   static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false);

    private:
    // internal
src/include/migraphx/iterator_for.hpp

@@ -17,9 +17,9 @@ struct iterator_for_range
    struct iterator
    {
        base_iterator i;
-       base_iterator operator*() { return i; }
+       base_iterator operator*() const { return i; }
        base_iterator operator++() { return ++i; }
-       bool operator!=(const iterator& rhs) { return i != rhs.i; }
+       bool operator!=(const iterator& rhs) const { return i != rhs.i; }
    };

    iterator begin()
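The two operators above only gain const qualifiers; the behavior is unchanged. iterator_for is a generic adaptor whose dereference yields the underlying iterator rather than the element, which is how MIGraphX code obtains instruction_ref values while walking a program. A small sketch over a std::vector, assuming the header is available:

#include <iostream>
#include <vector>
#include <migraphx/iterator_for.hpp>

int main()
{
    std::vector<int> v = {10, 20, 30};
    // it is an iterator into v (not an element), mirroring how instruction_ref
    // values are produced when iterating a program.
    for(auto it : migraphx::iterator_for(v))
        std::cout << *it << '\n';
}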
src/include/migraphx/literal.hpp

@@ -22,8 +22,8 @@ struct literal : raw_data<literal>
{
    literal() {}

-   template <class U, class T = deduce<U>>
-   literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(shape::get_type<T>{})
+   template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
+   literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
    {
        static_assert(std::is_trivially_copyable<T>{}, "Literals can only be trivial types");
        *(reinterpret_cast<T*>(buffer.get())) = x;
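The constructor change moves the shape-type lookup into a ShapeType non-type template parameter; call sites are unaffected. A minimal sketch of constructing a scalar literal, assuming the MIGraphX headers are available (elements() belongs to the existing shape API, not this diff):

#include <iostream>
#include <migraphx/literal.hpp>

int main()
{
    migraphx::literal l{3.0f}; // T is deduced as float; ShapeType becomes shape::float_type
    std::cout << l.get_shape().elements() << '\n'; // a scalar literal holds a single element
}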
src/include/migraphx/make_shared_array.hpp

@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
template <typename T>
std::shared_ptr<T> make_shared_array(size_t size)
{
-   return std::shared_ptr<T>(new T[size], std::default_delete<T[]>());
+   return std::shared_ptr<T>(new T[size], std::default_delete<T[]>()); // NOLINT
}

} // namespace MIGRAPHX_INLINE_NS
src/include/migraphx/matcher.hpp

@@ -214,7 +214,6 @@ void find_matches(program& p, Ms&&... ms)
    bool match = false;
    each_args(
        [&](auto&& m) {
-           // cppcheck-suppress knownConditionTrueFalse
            if(match)
                return;
            auto r = match_instruction(p, ins, m.matcher());
src/include/migraphx/onnx.hpp

@@ -7,6 +7,24 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

+struct unknown
+{
+    std::string op;
+    std::string name() const { return "unknown:" + op; }
+    shape compute_shape(std::vector<shape> input) const
+    {
+        if(input.empty())
+            return {};
+        else
+            return input.front();
+    }
+    friend std::ostream& operator<<(std::ostream& os, const unknown& x)
+    {
+        os << x.name();
+        return os;
+    }
+};
+
/// Create a program from an onnx file
program parse_onnx(const std::string& name);
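The new unknown struct is a placeholder operation for ONNX operators the parser does not recognize: its name is prefixed with "unknown:" and its compute_shape forwards the first input shape. A small illustrative sketch, assuming the headers are available ("FancyOp" is a made-up operator name):

#include <iostream>
#include <migraphx/onnx.hpp>

int main()
{
    migraphx::unknown u{"FancyOp"};
    std::cout << u << '\n'; // the friend operator<< prints "unknown:FancyOp"
}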
src/include/migraphx/operation.hpp

@@ -7,17 +7,17 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct context;

#ifdef DOXYGEN

/// The operation interface represents an action an instruction will perform. All

@@ -26,6 +26,8 @@ struct operation
{
    /// A unique name identifying the operation
    std::string name() const;
+   /// An optional method that can be used to finalize the operator before running
+   void finalize(context& ctx);
    /// This is used to compute the resulting shape from an operation. If an
    /// operation cannot be run with input shapes, then it should throw an
    /// exception.

@@ -53,6 +55,11 @@ struct operation
    friend std::ostream& operator<<(std::ostream& os, const operation& op);
};

+/// Returns true if operation does not require a context to run compute
+bool is_context_free(const operation& x);
+/// Returns true if the operation has a finalize method
+bool has_finalize(const operation& x);

#else

namespace operation_stream {

@@ -89,7 +96,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_equal

template <class T>
-auto compute_op(rank<1>,
+auto compute_op(rank<2>,
                const T& x,
                context& ctx,
                const shape& output_shape,

@@ -99,6 +106,14 @@ auto compute_op(rank<1>,
    return x.compute(auto_any_cast(ctx), output_shape, input);
}

+template <class T>
+auto compute_op(rank<1>, const T& x, context&, const shape& output_shape,
+                const std::vector<argument>& input) -> decltype(x.compute(output_shape, input))
+{
+    return x.compute(output_shape, input);
+}
+
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{

@@ -110,7 +125,53 @@ template <class T>
argument compute_op(const T& x,
                    context& ctx,
                    const shape& output_shape,
                    const std::vector<argument>& input)
{
-   return compute_op(rank<1>{}, x, ctx, output_shape, input);
+   return compute_op(rank<2>{}, x, ctx, output_shape, input);
}

+template <class T>
+auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
+    -> decltype(x.compute(output_shape, input))
+{
+    return x.compute(output_shape, input);
+}
+
+template <class T>
+auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
+    -> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
+{
+    std::string name = x.name();
+    MIGRAPHX_THROW("Not computable without a context: " + name);
+}
+
+template <class T>
+argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
+{
+    std::string name = x.name();
+    MIGRAPHX_THROW("Not computable: " + name);
+}
+
+template <class T>
+argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
+{
+    return compute_op(rank<2>{}, x, output_shape, input);
+}
+
+template <class T>
+auto is_context_free_op(rank<1>,
+                        const T& x,
+                        const shape& output_shape,
+                        const std::vector<argument>& input)
+    -> decltype(x.compute(output_shape, input), std::true_type{});
+
+template <class T>
+auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
+    -> std::false_type;
+
+template <class T>
+auto is_context_free_op(const T& x)
+    -> decltype(is_context_free_op(rank<1>{},
+                                   x,
+                                   std::declval<const shape&>(),
+                                   std::declval<std::vector<argument>>()))
+{
+    return {};
+}
+
template <class T>

@@ -132,15 +193,57 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
    return output_alias_op(rank<1>{}, x, shapes);
}

+template <class T>
+auto finalize_op(
+    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
+    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
+{
+    x.finalize(auto_any_cast(ctx), output_shape, input);
+}
+
+template <class T>
+void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
+{
+}
+
+template <class T>
+void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
+{
+    finalize_op(rank<1>{}, x, ctx, output_shape, input);
+}
+
+template <class T>
+auto has_finalize_op(
+    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
+    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});
+
+template <class T>
+auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
+    -> std::false_type;
+
+template <class T>
+auto has_finalize_op(const T&)
+    -> decltype(has_finalize_op(rank<1>{},
+                                std::declval<T&>(),
+                                std::declval<context&>(),
+                                std::declval<const shape&>(),
+                                std::declval<std::vector<shape>>()))
+{
+    return {};
+}
+
/*
 * Type-erased interface for:
 *
 * struct operation
 * {
 *      std::string name() const;
 *      bool is_context_free() const;
 *      bool has_finalize() const;
 *      int output_alias(const std::vector<shape>& input) const;
 *      void finalize(context& ctx, const shape& output, const std::vector<shape>& input);
 *      shape compute_shape(const std::vector<shape>& input) const;
 *      argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
 *      argument compute(const shape& output, const std::vector<argument>& input) const;
 *      friend std::ostream& operator<<(std::ostream& os, const operation& op);
 *      friend bool operator==(const operation& x, const operation& y);
 * };

@@ -210,12 +313,30 @@ struct operation
        return (*this).private_detail_te_get_handle().name();
    }

+   bool is_context_free() const
+   {
+       assert((*this).private_detail_te_handle_mem_var);
+       return (*this).private_detail_te_get_handle().is_context_free();
+   }
+
+   bool has_finalize() const
+   {
+       assert((*this).private_detail_te_handle_mem_var);
+       return (*this).private_detail_te_get_handle().has_finalize();
+   }
+
    int output_alias(const std::vector<shape>& input) const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().output_alias(input);
    }

+   void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
+   {
+       assert((*this).private_detail_te_handle_mem_var);
+       (*this).private_detail_te_get_handle().finalize(ctx, output, input);
+   }
+
    shape compute_shape(const std::vector<shape>& input) const
    {
        assert((*this).private_detail_te_handle_mem_var);

@@ -228,6 +349,12 @@ struct operation
        return (*this).private_detail_te_get_handle().compute(ctx, output, input);
    }

+   argument compute(const shape& output, const std::vector<argument>& input) const
+   {
+       assert((*this).private_detail_te_handle_mem_var);
+       return (*this).private_detail_te_get_handle().compute(output, input);
+   }
+
    friend std::ostream& operator<<(std::ostream& os, const operation& op)
    {
        assert(op.private_detail_te_handle_mem_var);

@@ -240,6 +367,12 @@ struct operation
        return x.private_detail_te_get_handle().operator==(y);
    }

+   friend bool is_shared(const operation& private_detail_x, const operation& private_detail_y)
+   {
+       return private_detail_x.private_detail_te_handle_mem_var ==
+              private_detail_y.private_detail_te_handle_mem_var;
+   }
+
    private:
    struct private_detail_te_handle_base_type
    {

@@ -247,13 +380,18 @@ struct operation
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
        virtual const std::type_info& type() const = 0;

-       virtual std::string name() const = 0;
-       virtual int output_alias(const std::vector<shape>& input) const = 0;
-       virtual shape compute_shape(const std::vector<shape>& input) const = 0;
+       virtual std::string name() const = 0;
+       virtual bool is_context_free() const = 0;
+       virtual bool has_finalize() const = 0;
+       virtual int output_alias(const std::vector<shape>& input) const = 0;
+       virtual void finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
+       virtual shape compute_shape(const std::vector<shape>& input) const = 0;
        virtual argument
        compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
+       virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
        virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
        virtual bool operator==(const operation& y) const = 0;
    };

    template <typename PrivateDetailTypeErasedT>

@@ -286,12 +424,26 @@ struct operation
        std::string name() const override { return private_detail_te_value.name(); }

+       bool is_context_free() const override
+       {
+           return is_context_free_op(private_detail_te_value);
+       }
+
+       bool has_finalize() const override
+       {
+           return has_finalize_op(private_detail_te_value);
+       }
+
        int output_alias(const std::vector<shape>& input) const override
        {
            return output_alias_op(private_detail_te_value, input);
        }

+       void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
+       {
+           finalize_op(private_detail_te_value, ctx, output, input);
+       }
+
        shape compute_shape(const std::vector<shape>& input) const override
        {

@@ -306,6 +458,12 @@ struct operation
            return compute_op(private_detail_te_value, ctx, output, input);
        }

+       argument compute(const shape& output, const std::vector<argument>& input) const override
+       {
+           return compute_op(private_detail_te_value, output, input);
+       }
+
        std::ostream& operator_shift_left(std::ostream& os) const override
        {
            using migraphx::operation_stream::operator<<;

@@ -385,6 +543,22 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }

+inline bool is_context_free(const operation& op) { return op.is_context_free(); }
+
+template <class T>
+bool is_context_free(const T& x)
+{
+    return is_context_free_op(x);
+}
+
+inline bool has_finalize(const operation& op) { return op.has_finalize(); }
+
+template <class T>
+bool has_finalize(const T& x)
+{
+    return has_finalize_op(x);
+}
+
#endif
} // namespace MIGRAPHX_INLINE_NS
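The detection helpers above (compute_op, is_context_free_op, finalize_op, has_finalize_op) let any plain struct with the right members be wrapped in the type-erased operation; the context-free compute overload and finalize are optional and detected through the rank<> dispatch. A hypothetical minimal operation to illustrate this ("passthrough_example" is not a real MIGraphX operator):

#include <string>
#include <vector>
#include <migraphx/operation.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>

struct passthrough_example
{
    std::string name() const { return "passthrough_example"; }

    // Required: derive the output shape from the input shapes.
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
    {
        return inputs.front();
    }

    // Because this overload takes no context&, is_context_free_op deduces true
    // and instruction::eval can evaluate it ahead of time.
    migraphx::argument compute(const migraphx::shape&,
                               const std::vector<migraphx::argument>& args) const
    {
        return args.front(); // identity: forward the first input unchanged
    }
};

int main()
{
    migraphx::operation op = passthrough_example{};
    return (op.is_context_free() and not op.has_finalize()) ? 0 : 1;
}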
src/include/migraphx/operators.hpp

(This diff is collapsed in this view.)
src/include/migraphx/par_dfor.hpp (new file, 0 → 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP

#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
#include <array>
#include <numeric>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template <class... Ts>
auto par_dfor(Ts... xs)
{
    return [=](auto f) {
        using array_type = std::array<std::size_t, sizeof...(Ts)>;
        array_type lens  = {{static_cast<std::size_t>(xs)...}};
        auto n = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>{});
        const std::size_t min_grain = 8;
        if(n > 2 * min_grain)
        {
            array_type strides;
            strides.fill(1);
            std::partial_sum(lens.rbegin(),
                             lens.rend() - 1,
                             strides.rbegin() + 1,
                             std::multiplies<std::size_t>());
            auto size =
                std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
            par_for(size, min_grain, [&](std::size_t i) {
                array_type indices;
                std::transform(strides.begin(),
                               strides.end(),
                               lens.begin(),
                               indices.begin(),
                               [&](size_t stride, size_t len) { return (i / stride) % len; });
                migraphx::unpack(f, indices);
            });
        }
        else
        {
            dfor(xs...)(f);
        }
    };
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
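par_dfor flattens the multi-dimensional index space; when it holds more than 2 * min_grain elements it hands the flat range to par_for and reconstructs the per-dimension indices from strides, otherwise it falls back to the serial dfor. A usage sketch, assuming the MIGraphX headers are on the include path:

#include <cstddef>
#include <vector>
#include <migraphx/par_dfor.hpp>

int main()
{
    const std::size_t m = 64;
    const std::size_t n = 64;
    std::vector<float> c(m * n);
    // The 64x64 index space (4096 > 2 * min_grain) is split across threads;
    // each (i, j) pair is visited exactly once.
    migraphx::par_dfor(m, n)([&](std::size_t i, std::size_t j) { c[i * n + j] = float(i + j); });
    return 0;
}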
src/include/migraphx/par_for.hpp (new file, 0 → 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP

#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct joinable_thread : std::thread
{
    template <class... Xs>
    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
    {
    }

    joinable_thread& operator=(joinable_thread&& other) = default;
    joinable_thread(joinable_thread&& other)            = default;

    ~joinable_thread()
    {
        if(this->joinable())
            this->join();
    }
};

template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
    if(threadsize <= 1)
    {
        for(std::size_t i = 0; i < n; i++)
            f(i);
    }
    else
    {
        std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
        const
#endif
            std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());

        std::size_t work = 0;
        std::generate(threads.begin(), threads.end(), [=, &work] {
            auto result = joinable_thread([=] {
                std::size_t start = work;
                std::size_t last  = std::min(n, work + grainsize);
                for(std::size_t i = start; i < last; i++)
                {
                    f(i);
                }
            });
            work += grainsize;
            return result;
        });
        assert(work >= n);
    }
}

template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
{
    const auto threadsize =
        std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain);
    par_for_impl(n, threadsize, f);
}

template <class F>
void par_for(std::size_t n, F f)
{
    const int min_grain = 8;
    par_for(n, min_grain, f);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
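par_for splits [0, n) into contiguous chunks of at least min_grain (default 8) iterations, runs each chunk on its own joinable_thread, and joins them as the threads are destroyed; with a single thread it degenerates to a plain loop. A usage sketch, assuming the MIGraphX headers are on the include path:

#include <cstddef>
#include <vector>
#include <migraphx/par_for.hpp>

int main()
{
    std::vector<double> x(1 << 20, 1.0);
    std::vector<double> y(x.size());
    // Each index is handled by exactly one thread, so writing y[i] needs no locking.
    migraphx::par_for(x.size(), [&](std::size_t i) { y[i] = 2.0 * x[i]; });
    return 0;
}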
src/include/migraphx/pass.hpp

@@ -105,7 +105,13 @@ struct pass
    void apply(program& p) const
    {
        assert((*this).private_detail_te_handle_mem_var);
-       return (*this).private_detail_te_get_handle().apply(p);
+       (*this).private_detail_te_get_handle().apply(p);
    }

+   friend bool is_shared(const pass& private_detail_x, const pass& private_detail_y)
+   {
+       return private_detail_x.private_detail_te_handle_mem_var ==
+              private_detail_y.private_detail_te_handle_mem_var;
+   }
+
    private:

@@ -149,7 +155,7 @@ struct pass
        std::string name() const override { return private_detail_te_value.name(); }
-       void apply(program& p) const override { return private_detail_te_value.apply(p); }
+       void apply(program& p) const override { private_detail_te_value.apply(p); }

        PrivateDetailTypeErasedT private_detail_te_value;
    };
src/include/migraphx/program.hpp

@@ -91,16 +91,22 @@ struct program
    shape get_shape() const;

    context& get_context() const;

    instruction_ref validate() const;

    void compile(const target& t, tracer trace = tracer{});

+   void finalize();
+
    void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;

+   void debug_print() const;
+   void debug_print(instruction_ref ins) const;
+   void debug_print(const std::vector<instruction_ref>& inss) const;
+   void dry_run(parameter_map params) const;
+
    friend std::ostream& operator<<(std::ostream& os, const program& p);

    friend bool operator==(const program& x, const program& y);
    friend bool operator!=(const program& x, const program& y) { return !(x == y); }
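The new finalize(), debug_print() and dry_run() entry points sit alongside the existing compile/eval flow. A hypothetical end-to-end sketch, assuming the ONNX parser and the CPU target are built ("model.onnx" is a placeholder path):

#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/cpu/target.hpp>

int main()
{
    migraphx::program p = migraphx::parse_onnx("model.onnx"); // placeholder file name
    p.compile(migraphx::cpu::target{}); // lowers and finalizes the instructions for the target
    p.debug_print();                    // new helper: dump the compiled instruction stream
    return 0;
}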
src/include/migraphx/rewrite_rnn.hpp (new file, 0 → 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP

#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;

/**
 * Rewrite rnn to gemm and add.
 */
struct rewrite_rnn
{
    std::string name() const { return "rewrite_rnn"; }
    void apply(program& prog) const;

    private:
    // for vanilla rnn operators
    void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
    std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
                                                  program& prog,
                                                  instruction_ref ins,
                                                  instruction_ref input,
                                                  instruction_ref w,
                                                  instruction_ref r,
                                                  instruction_ref bias,
                                                  instruction_ref ih,
                                                  operation& actv_func) const;
    std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;

    // for gru operators
    void apply_gru(program& prog, instruction_ref ins) const;
    std::vector<instruction_ref> gru_cell(bool is_forward,
                                          program& prog,
                                          instruction_ref ins,
                                          std::vector<instruction_ref> inputs,
                                          int linear_before_reset,
                                          const operation& actv_func1,
                                          const operation& actv_func2) const;
    std::vector<operation> gru_actv_funcs(instruction_ref ins) const;

    // for lstm operators
    void apply_lstm(program& prog, instruction_ref ins) const;
    std::vector<instruction_ref> lstm_cell(bool is_forward,
                                           program& prog,
                                           instruction_ref ins,
                                           std::vector<instruction_ref> inputs,
                                           const operation& actv_func1,
                                           const operation& actv_func2,
                                           const operation& actv_func3) const;
    std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
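rewrite_rnn matches the pass interface shown in pass.hpp (name() plus apply(program&)), so a target can schedule it like any other pass; it can also be applied by hand. A hypothetical sketch, assuming an ONNX model containing RNN/GRU/LSTM operators ("rnn_model.onnx" is a placeholder path):

#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/rewrite_rnn.hpp>

int main()
{
    migraphx::program p = migraphx::parse_onnx("rnn_model.onnx"); // placeholder file name
    // Lower rnn/gru/lstm instructions into gemm, add and activation instructions.
    migraphx::rewrite_rnn{}.apply(p);
    return 0;
}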
src/include/migraphx/shape.hpp

@@ -35,22 +35,22 @@ struct shape
    m(uint64_type, uint64_t)
    // clang-format on

-#define MIGRAPHX_SHAPE_ENUM_TYPES(x, t) x,
+#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
    enum type_t
    {
-       MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_ENUM_TYPES)
+       MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
    };
-#undef MIGRAPHX_SHAPE_ENUM_TYPES
+#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES

    template <class T, class = void>
    struct get_type;
-#define MIGRAPHX_SHAPE_GET_TYPE(x, t)                         \
+#define MIGRAPHX_SHAPE_GENERATE_GET_TYPE(x, t)                \
    template <class T>                                         \
    struct get_type<t, T> : std::integral_constant<type_t, x>  \
    {                                                          \
    };
-   MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GET_TYPE)
-#undef MIGRAPHX_SHAPE_GET_TYPE
+   MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_GET_TYPE)
+#undef MIGRAPHX_SHAPE_GENERATE_GET_TYPE

    template <class T>
    struct get_type<const T> : get_type<T>

@@ -62,6 +62,19 @@ struct shape
    shape(type_t t, std::vector<std::size_t> l);
    shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);

+   template <class Range>
+   shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
+   {
+   }
+
+   template <class Range1, class Range2>
+   shape(type_t t, const Range1& l, const Range2& s)
+       : shape(t,
+               std::vector<std::size_t>(l.begin(), l.end()),
+               std::vector<std::size_t>(s.begin(), s.end()))
+   {
+   }
+
    type_t type() const;
    const std::vector<std::size_t>& lens() const;
    const std::vector<std::size_t>& strides() const;

@@ -141,6 +154,8 @@ struct shape
        {
            return reinterpret_cast<const T*>(buffer) + n;
        }

+       type_t type_enum() const { return get_type<T>{}; }
    };

    template <class Visitor>

@@ -148,14 +163,22 @@ struct shape
    {
        switch(this->type())
        {
-#define MIGRAPHX_SHAPE_VISITOR_CASE(x, t) \
+#define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
    case x: v(as<t>()); return;
-           MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_VISITOR_CASE)
-#undef MIGRAPHX_SHAPE_VISITOR_CASE
+           MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
+#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE
        }
        MIGRAPHX_THROW("Unknown type");
    }

+   template <class Visitor>
+   static void visit_types(Visitor v)
+   {
+#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) v(as<t>());
+       MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL)
+#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
+   }
+
    private:
    std::shared_ptr<const shape_impl> impl;
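The new templated constructors accept any range with begin()/end() for the lengths and strides, so callers no longer have to copy into std::vector<std::size_t> first. A small sketch, assuming the MIGraphX headers are available (elements() is part of the existing shape API):

#include <array>
#include <migraphx/shape.hpp>

int main()
{
    std::array<std::size_t, 3> lens = {{2, 3, 4}};
    // Any begin()/end() range is accepted by the new constructor overloads.
    migraphx::shape s{migraphx::shape::float_type, lens};
    return s.elements() == 24 ? 0 : 1;
}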
src/include/migraphx/target.hpp

@@ -127,6 +127,12 @@ struct target
        return (*this).private_detail_te_get_handle().get_context();
    }

+   friend bool is_shared(const target& private_detail_x, const target& private_detail_y)
+   {
+       return private_detail_x.private_detail_te_handle_mem_var ==
+              private_detail_y.private_detail_te_handle_mem_var;
+   }
+
    private:
    struct private_detail_te_handle_base_type
    {
src/include/migraphx/tensor_view.hpp

@@ -124,6 +124,8 @@ struct tensor_view
        return m_data + this->size();
    }

+   std::vector<T> to_vector() const { return std::vector<T>(this->begin(), this->end()); }
+
    friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
    {
        if(!x.empty())

@@ -164,7 +166,7 @@ bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
}

template <class T>
-tensor_view<T> make_view(shape s, T* data)
+tensor_view<T> make_view(const shape& s, T* data)
{
    return {s, data};
}
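to_vector() copies the viewed elements into an owned std::vector, and make_view now takes the shape by const reference. A small sketch, assuming the MIGraphX headers are available:

#include <vector>
#include <migraphx/shape.hpp>
#include <migraphx/tensor_view.hpp>

int main()
{
    std::vector<float> data = {1, 2, 3, 4, 5, 6};
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    auto view = migraphx::make_view(s, data.data()); // non-owning view over data
    std::vector<float> copy = view.to_vector();      // new helper: copy out the elements
    return copy.size() == data.size() ? 0 : 1;
}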
src/include/migraphx/type_name.hpp

@@ -18,7 +18,7 @@ const std::string& get_type_name()
    name = typeid(PrivateMigraphTypeNameProbe).name();
    name = name.substr(7);
#else
-   const char parameter_name[] = "PrivateMigraphTypeNameProbe =";
+   const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT

    name = __PRETTY_FUNCTION__;
src/instruction.cpp

@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
bool operator==(const instruction& x, const instruction& y)
{
-   if(not(x.result == y.result and x.op == y.op and x.arguments == y.arguments))
+   if(std::tie(x.result, x.op, x.arguments) != std::tie(y.result, y.op, y.arguments))
        return false;
    if(x.name() == "@literal")
        return x.lit == y.lit;

@@ -162,25 +162,54 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
    old->remove_output(*this);
}

-std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
+argument instruction::eval() const
{
-   std::vector<shape> shapes(args.size());
-   std::transform(args.begin(), args.end(), shapes.begin(), [](instruction_ref i) {
-       return i->get_shape();
-   });
-   return shapes;
+   if(op.name() == "@literal")
+   {
+       return this->get_literal().get_argument();
+   }
+   if(is_context_free(op))
+   {
+       std::vector<argument> args;
+       for(auto&& arg : this->inputs())
+       {
+           argument a = arg->eval();
+           if(a.empty())
+               return {};
+           args.push_back(a);
+       }
+       return op.compute(result, args);
+   }
+   return {};
}

-instruction_ref instruction::get_output_alias(instruction_ref ins)
+void instruction::finalize(context& ctx)
{
-   auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
+   if(has_finalize(this->op))
+       this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
+}
+
+instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
+{
+   auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
    if(i < 0)
        return ins;
+   if(shallow)
+       return ins->inputs().at(i);
    return get_output_alias(ins->inputs().at(i));
}

+std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
+{
+   std::vector<shape> shapes(args.size());
+   std::transform(args.begin(), args.end(), shapes.begin(), [](instruction_ref i) {
+       return i->get_shape();
+   });
+   return shapes;
+}
+
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
-   return op.compute_shape(compute_shapes(args));
+   return op.compute_shape(to_shapes(args));
}

} // namespace MIGRAPHX_INLINE_NS
src/onnx/mnist.cpp

@@ -14,7 +14,10 @@
auto reverse_int(unsigned int i)
{
-   unsigned char c1, c2, c3, c4;
+   unsigned char c1;
+   unsigned char c2;
+   unsigned char c3;
+   unsigned char c4;

    c1 = i & 255u;
    c2 = (i >> 8u) & 255u;
    c3 = (i >> 16u) & 255u;

@@ -32,7 +35,9 @@ read_mnist_images(const std::string& full_path, int& number_of_images, int& image
    if(file.is_open())
    {
-       int magic_number = 0, n_rows = 0, n_cols = 0;
+       int magic_number = 0;
+       int n_rows       = 0;
+       int n_cols       = 0;

        file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
        magic_number = reverse_int(magic_number);