Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a5c1c7f6
Unverified
Commit
a5c1c7f6
authored
Feb 10, 2019
by
Paul Fultz II
Committed by
GitHub
Feb 10, 2019
Browse files
Merge branch 'develop' into mem_color_ordering_fix
parents
462a4920
d516b099
Changes
303
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
798 additions
and
177 deletions
+798
-177
src/include/migraphx/erase.hpp
src/include/migraphx/erase.hpp
+4
-4
src/include/migraphx/errors.hpp
src/include/migraphx/errors.hpp
+5
-5
src/include/migraphx/fallthrough.hpp
src/include/migraphx/fallthrough.hpp
+6
-6
src/include/migraphx/float_equal.hpp
src/include/migraphx/float_equal.hpp
+6
-6
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+10
-4
src/include/migraphx/fwd_conv_batchnorm_rewrite.hpp
src/include/migraphx/fwd_conv_batchnorm_rewrite.hpp
+4
-4
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+7
-7
src/include/migraphx/half.hpp
src/include/migraphx/half.hpp
+4
-4
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+10
-5
src/include/migraphx/instruction_ref.hpp
src/include/migraphx/instruction_ref.hpp
+4
-4
src/include/migraphx/iterator_for.hpp
src/include/migraphx/iterator_for.hpp
+6
-6
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+4
-4
src/include/migraphx/make_shared_array.hpp
src/include/migraphx/make_shared_array.hpp
+5
-5
src/include/migraphx/manage_ptr.hpp
src/include/migraphx/manage_ptr.hpp
+5
-5
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+12
-13
src/include/migraphx/memory_coloring.hpp
src/include/migraphx/memory_coloring.hpp
+4
-4
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+22
-4
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+189
-15
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+440
-72
src/include/migraphx/par_dfor.hpp
src/include/migraphx/par_dfor.hpp
+51
-0
No files found.
src/include/migraphx/erase.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_ERASE_HPP
#define MIGRAPH_GUARD_ERASE_HPP
#ifndef MIGRAPH
X
_GUARD_ERASE_HPP
#define MIGRAPH
X
_GUARD_ERASE_HPP
#include <algorithm>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
/**
* @brief Erase all elements from a container
...
...
@@ -33,7 +33,7 @@ auto erase_if(R&& r, P&& pred)
return
r
.
erase
(
std
::
remove_if
(
r
.
begin
(),
r
.
end
(),
pred
),
r
.
end
());
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/errors.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_ERRORS_HPP
#define MIGRAPH_GUARD_ERRORS_HPP
#ifndef MIGRAPH
X
_GUARD_ERRORS_HPP
#define MIGRAPH
X
_GUARD_ERRORS_HPP
#include <exception>
#include <stdexcept>
...
...
@@ -7,7 +7,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
/// Represents exceptions that can be thrown by migraphxlib
struct
exception
:
std
::
runtime_error
...
...
@@ -43,10 +43,10 @@ inline std::string make_source_context(const std::string& file, int line)
/**
* @brief Throw an exception with context information
*/
#define MIGRAPH_THROW(...) \
#define MIGRAPH
X
_THROW(...) \
throw migraphx::make_exception(migraphx::make_source_context(__FILE__, __LINE__), __VA_ARGS__)
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/fallthrough.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_FALLTHROUGH_HPP
#define MIGRAPH_GUARD_FALLTHROUGH_HPP
#ifndef MIGRAPH
X
_GUARD_FALLTHROUGH_HPP
#define MIGRAPH
X
_GUARD_FALLTHROUGH_HPP
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
#ifdef __clang__
#define MIGRAPH_FALLTHROUGH [[clang::fallthrough]]
#define MIGRAPH
X
_FALLTHROUGH [[clang::fallthrough]]
#else
#define MIGRAPH_FALLTHROUGH
#define MIGRAPH
X
_FALLTHROUGH
#endif
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/float_equal.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#include <algorithm>
#include <cmath>
...
...
@@ -12,14 +12,14 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
...
Ts
>
using
common_type
=
typename
std
::
common_type
<
Ts
...
>::
type
;
struct
float_equal_fn
{
template
<
class
T
,
MIGRAPH_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
template
<
class
T
,
MIGRAPH
X
_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
static
bool
apply
(
T
x
,
T
y
)
{
return
std
::
isfinite
(
x
)
and
std
::
isfinite
(
y
)
and
...
...
@@ -27,7 +27,7 @@ struct float_equal_fn
std
::
nextafter
(
x
,
std
::
numeric_limits
<
T
>::
max
())
>=
y
;
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
not
std
::
is_floating_point
<
T
>{})
>
template
<
class
T
,
MIGRAPH
X
_REQUIRES
(
not
std
::
is_floating_point
<
T
>{})
>
static
bool
apply
(
T
x
,
T
y
)
{
return
x
==
y
;
...
...
@@ -42,7 +42,7 @@ struct float_equal_fn
static
constexpr
float_equal_fn
float_equal
{};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/functional.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#define MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_FUNCTIONAL_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_FUNCTIONAL_HPP
#include <utility>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
swallow
{
...
...
@@ -94,6 +94,12 @@ constexpr void each_args(F)
{
}
template
<
class
F
,
class
T
>
auto
unpack
(
F
f
,
T
&
x
)
{
return
sequence_c
<
std
::
tuple_size
<
T
>
{}
>
([
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
x
)...);
});
}
/// Implements a fix-point combinator
template
<
class
R
,
class
F
>
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
...
...
@@ -131,7 +137,7 @@ auto fold(F f)
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
};
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/fwd_conv_batchnorm_rewrite.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
program
;
...
...
@@ -19,7 +19,7 @@ struct fwd_conv_batchnorm_rewrite
void
apply
(
program
&
p
)
const
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/generate.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_GENERATE_HPP
#include <migraphx/argument.hpp>
#include <migraphx/literal.hpp>
...
...
@@ -8,9 +8,9 @@
#include <random>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
T
,
MIGRAPH_REQUIRES
(
is_floating_point
<
T
>{})
>
template
<
class
T
,
MIGRAPH
X
_REQUIRES
(
is_floating_point
<
T
>{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
if
(
z
==
0
)
...
...
@@ -22,7 +22,7 @@ constexpr T normalize(unsigned long z)
return
T
(
result
);
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
template
<
class
T
,
MIGRAPH
X
_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
...
...
@@ -30,7 +30,7 @@ constexpr T normalize(unsigned long z)
return
half_max
-
(
z
%
max
);
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
not
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
template
<
class
T
,
MIGRAPH
X
_REQUIRES
(
not
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
...
...
@@ -93,7 +93,7 @@ literal generate_literal(shape s, unsigned long seed = 0);
literal
abs
(
literal
l
);
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/half.hpp
View file @
a5c1c7f6
...
...
@@ -5,14 +5,14 @@
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_HALF_HPP
#define MIGRAPH_GUARD_RTGLIB_HALF_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_HALF_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_HALF_HPP
#include <half.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
using
half
=
half_float
::
half
;
...
...
@@ -33,7 +33,7 @@ struct deduce<half_float::detail::expr>
template
<
class
T
>
using
deduce
=
typename
detail
::
deduce
<
T
>::
type
;
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/instruction.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#include <migraphx/literal.hpp>
#include <migraphx/shape.hpp>
...
...
@@ -11,9 +11,10 @@
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
);
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
);
struct
instruction
{
...
...
@@ -71,7 +72,11 @@ struct instruction
static
void
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
static
instruction_ref
get_output_alias
(
instruction_ref
ins
);
argument
eval
()
const
;
void
finalize
(
context
&
ctx
);
static
instruction_ref
get_output_alias
(
instruction_ref
ins
,
bool
shallow
=
false
);
private:
// internal
...
...
@@ -90,7 +95,7 @@ struct instruction
std
::
vector
<
instruction_ref
>
arguments
;
literal
lit
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
namespace
std
{
...
...
src/include/migraphx/instruction_ref.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#define MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#ifndef MIGRAPH
X
_GUARD_INSTRUCTION_REF_HPP
#define MIGRAPH
X
_GUARD_INSTRUCTION_REF_HPP
#include <list>
#include <functional>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
instruction
;
using
instruction_ref
=
std
::
list
<
instruction
>::
iterator
;
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/iterator_for.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_ITERATOR_FOR_HPP
#include <cassert>
#include <type_traits>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
T
>
struct
iterator_for_range
...
...
@@ -17,9 +17,9 @@ struct iterator_for_range
struct
iterator
{
base_iterator
i
;
base_iterator
operator
*
()
{
return
i
;
}
base_iterator
operator
*
()
const
{
return
i
;
}
base_iterator
operator
++
()
{
return
++
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
{
return
i
!=
rhs
.
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
const
{
return
i
!=
rhs
.
i
;
}
};
iterator
begin
()
...
...
@@ -39,7 +39,7 @@ iterator_for_range<T> iterator_for(T& x)
return
{
&
x
};
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/literal.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_LITERAL_HPP
#include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp>
...
...
@@ -12,7 +12,7 @@
#include <memory>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
/**
* @brief Represents a raw literal
...
...
@@ -124,7 +124,7 @@ literal transform(literal l1, literal l2, F f)
return
result
;
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/make_shared_array.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#include <memory>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
typename
T
>
std
::
shared_ptr
<
T
>
make_shared_array
(
size_t
size
)
{
return
std
::
shared_ptr
<
T
>
(
new
T
[
size
],
std
::
default_delete
<
T
[]
>
());
return
std
::
shared_ptr
<
T
>
(
new
T
[
size
],
std
::
default_delete
<
T
[]
>
());
// NOLINT
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/manage_ptr.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#define MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPH
X
_MANAGE_PTR_HPP
#define MIGRAPH
X
_GUARD_MIGRAPH
X
_MANAGE_PTR_HPP
#include <memory>
#include <type_traits>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
F
,
F
f
>
// NOLINT
struct
manage_deleter
...
...
@@ -51,10 +51,10 @@ shared<T> share(T p)
return
shared
<
T
>
{
std
::
move
(
p
)};
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#define MIGRAPH_MANAGE_PTR(T, F) \
#define MIGRAPH
X
_MANAGE_PTR(T, F) \
migraphx::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#endif
src/include/migraphx/matcher.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_MATCHER_HPP
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
...
...
@@ -10,7 +10,7 @@
#include <unordered_map>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
namespace
match
{
...
...
@@ -169,7 +169,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
}
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPH_BASIC_MATCHER(name, ...)
\
#define MIGRAPH
X
_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
...
...
@@ -178,7 +178,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
inline instruction_ref name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPH_PRED_MATCHER(name, ...)
\
#define MIGRAPH
X
_PRED_MATCHER(name, ...) \
struct name##_m \
{ \
bool operator()(__VA_ARGS__) const; \
...
...
@@ -214,7 +214,6 @@ void find_matches(program& p, Ms&&... ms)
bool
match
=
false
;
each_args
(
[
&
](
auto
&&
m
)
{
// cppcheck-suppress knownConditionTrueFalse
if
(
match
)
return
;
auto
r
=
match_instruction
(
p
,
ins
,
m
.
matcher
());
...
...
@@ -266,22 +265,22 @@ auto any_of(Ts... ms)
});
}
MIGRAPH_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPH_PRED_MATCHER
(
none
,
instruction_ref
)
{
return
false
;
}
MIGRAPH_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPH_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
MIGRAPH
X
_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPH
X
_PRED_MATCHER
(
none
,
instruction_ref
)
{
return
false
;
}
MIGRAPH
X
_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPH
X
_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
broadcasted
();
}
MIGRAPH_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
MIGRAPH
X
_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
return
ins
->
outputs
().
front
();
return
ctx
.
not_found
();
}
MIGRAPH_BASIC_MATCHER
(
used_once
,
matcher_context
&
ctx
,
instruction_ref
ins
)
MIGRAPH
X
_BASIC_MATCHER
(
used_once
,
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
return
ins
;
...
...
@@ -340,7 +339,7 @@ inline auto either_arg(std::size_t i, std::size_t j)
}
}
// namespace match
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/memory_coloring.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_MEMORY_COLORING_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_MEMORY_COLORING_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
program
;
/**
...
...
@@ -20,7 +20,7 @@ struct memory_coloring
void
apply
(
program
&
p
)
const
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/onnx.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_ONNX_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_ONNX_HPP
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
unknown
{
std
::
string
op
;
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
if
(
input
.
empty
())
return
{};
else
return
input
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
/// Create a program from an onnx file
program
parse_onnx
(
const
std
::
string
&
name
);
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operation.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_OPERAND_HPP
#include <cassert>
#include <string>
...
...
@@ -7,16 +7,16 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
context
;
#ifdef DOXYGEN
...
...
@@ -26,6 +26,8 @@ struct operation
{
/// A unique name identifying the operation
std
::
string
name
()
const
;
/// An optional method that can be used to finalize the operator before running
void
finalize
(
context
&
ctx
);
/// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an
/// exception.
...
...
@@ -53,6 +55,11 @@ struct operation
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
};
/// Returns true if operation does not require a context to run compute
bool
is_context_free
(
const
operation
&
x
);
/// Returns true if the operation has a finalize method
bool
has_finalize
(
const
operation
&
x
);
#else
namespace
operation_stream
{
...
...
@@ -89,7 +96,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
}
// namespace operation_equal
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
...
...
@@ -99,18 +106,72 @@ auto compute_op(rank<1>,
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
std
::
string
name
=
x
.
name
();
MIGRAPH_THROW
(
"Not computable: "
+
name
);
MIGRAPH
X
_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
return
compute_op
(
rank
<
2
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
std
::
declval
<
context
&>
()),
output_shape
,
input
))
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable without a context: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
2
>
{},
x
,
output_shape
,
input
);
}
template
<
class
T
>
auto
is_context_free_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
),
std
::
true_type
{});
template
<
class
T
>
auto
is_context_free_op
(
rank
<
0
>
,
const
T
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
->
std
::
false_type
;
template
<
class
T
>
auto
is_context_free_op
(
const
T
&
x
)
->
decltype
(
is_context_free_op
(
rank
<
1
>
{},
x
,
std
::
declval
<
const
shape
&>
(),
std
::
declval
<
std
::
vector
<
argument
>>
()))
{
return
{};
}
template
<
class
T
>
...
...
@@ -132,15 +193,57 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return
output_alias_op
(
rank
<
1
>
{},
x
,
shapes
);
}
template
<
class
T
>
auto
finalize_op
(
rank
<
1
>
,
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
->
decltype
(
x
.
finalize
(
auto_any_cast
(
ctx
),
output_shape
,
input
),
void
())
{
x
.
finalize
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
template
<
class
T
>
void
finalize_op
(
rank
<
0
>
,
T
&
,
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
{
}
template
<
class
T
>
void
finalize_op
(
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
{
finalize_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
has_finalize_op
(
rank
<
1
>
,
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
->
decltype
(
x
.
finalize
(
auto_any_cast
(
ctx
),
output_shape
,
input
),
std
::
true_type
{});
template
<
class
T
>
auto
has_finalize_op
(
rank
<
0
>
,
T
&
,
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
->
std
::
false_type
;
template
<
class
T
>
auto
has_finalize_op
(
const
T
&
)
->
decltype
(
has_finalize_op
(
rank
<
1
>
{},
std
::
declval
<
T
&>
(),
std
::
declval
<
context
&>
(),
std
::
declval
<
const
shape
&>
(),
std
::
declval
<
std
::
vector
<
shape
>>
()))
{
return
{};
}
/*
* Type-erased interface for:
*
* struct operation
* {
* std::string name() const;
* bool is_context_free() const;
* bool has_finalize() const;
* int output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
...
...
@@ -210,12 +313,30 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
bool
is_context_free
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
is_context_free
();
}
bool
has_finalize
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
has_finalize
();
}
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
output_alias
(
input
);
}
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
finalize
(
ctx
,
output
,
input
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
...
@@ -228,6 +349,12 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
ctx
,
output
,
input
);
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
output
,
input
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
assert
(
op
.
private_detail_te_handle_mem_var
);
...
...
@@ -240,6 +367,12 @@ struct operation
return
x
.
private_detail_te_get_handle
().
operator
==
(
y
);
}
friend
bool
is_shared
(
const
operation
&
private_detail_x
,
const
operation
&
private_detail_y
)
{
return
private_detail_x
.
private_detail_te_handle_mem_var
==
private_detail_y
.
private_detail_te_handle_mem_var
;
}
private:
struct
private_detail_te_handle_base_type
{
...
...
@@ -247,13 +380,18 @@ struct operation
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
bool
is_context_free
()
const
=
0
;
virtual
bool
has_finalize
()
const
=
0
;
virtual
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
};
template
<
typename
PrivateDetailTypeErasedT
>
...
...
@@ -286,12 +424,26 @@ struct operation
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
bool
is_context_free
()
const
override
{
return
is_context_free_op
(
private_detail_te_value
);
}
bool
has_finalize
()
const
override
{
return
has_finalize_op
(
private_detail_te_value
);
}
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
{
return
output_alias_op
(
private_detail_te_value
,
input
);
}
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
override
{
finalize_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
override
{
...
...
@@ -306,6 +458,12 @@ struct operation
return
compute_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
override
{
return
compute_op
(
private_detail_te_value
,
output
,
input
);
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
using
migraphx
::
operation_stream
::
operator
<<
;
...
...
@@ -385,9 +543,25 @@ inline const ValueType& any_cast(const operation& x)
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
is_context_free
(
const
operation
&
op
)
{
return
op
.
is_context_free
();
}
template
<
class
T
>
bool
is_context_free
(
const
T
&
x
)
{
return
is_context_free_op
(
x
);
}
inline
bool
has_finalize
(
const
operation
&
op
)
{
return
op
.
has_finalize
();
}
template
<
class
T
>
bool
has_finalize
(
const
T
&
x
)
{
return
has_finalize_op
(
x
);
}
#endif
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
a5c1c7f6
This diff is collapsed.
Click to expand it.
src/include/migraphx/par_dfor.hpp
0 → 100644
View file @
a5c1c7f6
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
#include <array>
#include <numeric>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
...
Ts
>
auto
par_dfor
(
Ts
...
xs
)
{
return
[
=
](
auto
f
)
{
using
array_type
=
std
::
array
<
std
::
size_t
,
sizeof
...(
Ts
)
>
;
array_type
lens
=
{{
static_cast
<
std
::
size_t
>
(
xs
)...}};
auto
n
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
{});
const
std
::
size_t
min_grain
=
8
;
if
(
n
>
2
*
min_grain
)
{
array_type
strides
;
strides
.
fill
(
1
);
std
::
partial_sum
(
lens
.
rbegin
(),
lens
.
rend
()
-
1
,
strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
size
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
par_for
(
size
,
min_grain
,
[
&
](
std
::
size_t
i
)
{
array_type
indices
;
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
lens
.
begin
(),
indices
.
begin
(),
[
&
](
size_t
stride
,
size_t
len
)
{
return
(
i
/
stride
)
%
len
;
});
migraphx
::
unpack
(
f
,
indices
);
});
}
else
{
dfor
(
xs
...)(
f
);
}
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
Prev
1
2
3
4
5
6
7
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment